/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/*
  General architecture

  For diff states, we have n_states + 1 slots: n_states diffs to propagate
  to the previous iteration and one state to propagate to the previous
  layer.
  Index 0 is dh for cell(t-1, l) to consume,
  index 1 is dc for cell(t-1, l) to consume,
  index 2 is dh for cell(t, l-1) to consume.
  This indexing lets the elemwise functions use the same indexing for all
  states; only the cell execution functions need to care about it.
 */
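
/*
  Illustration of the convention above (not part of the build): for an LSTM
  cell at layer l, direction dir and time step t, the backward pass would
  roughly read/write the diff-states workspace as

      dh(t-1, l) -> ws_diff_states(l, dir, 0, t, ...);         // to prev iter
      dc(t-1, l) -> ws_diff_states(l, dir, 1, t, ...);         // to prev iter
      dx         -> ws_diff_states(l, dir, n_states, t, ...);  // to prev layer

  with n_states == 2 for LSTM. The cell functions below pass these
  sub-buffers explicitly; this sketch only shows the index convention.
 */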

#include "math_utils.hpp"
#include "mkldnn_thread.hpp"

#include "ref_rnn.hpp"
#include "../gemm/gemm.hpp"
#include "../simple_q10n.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::memory_tracking::names;
using namespace rnn_utils;
#define AOC array_offset_calculator
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
void _ref_rnn_common_t<aprop, src_type, weights_type>::gates_reduction(
        const rnn_conf_t &rnn, const acc_data_t *ws_gates_,
        float *diff_bias_) const {
    auto body = [&](int i, int k) {
        for (int j = 0; j < rnn.mb; j++)
            diff_bias_[i * rnn.dic + k]
                    += ws_gates_[j * rnn.gates_ws_ld + i * rnn.dic + k];
    };

    // @todo block k on simd-width
#if MKLDNN_THR == MKLDNN_THR_OMP && _OPENMP >= 201307 \
    /* icc 17.0 has a problem with simd collapse */ \
    && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700))
#pragma omp parallel for simd collapse(2)
    for (int i = 0; i < rnn.n_gates; i++)
        for (int k = 0; k < rnn.dic; k++)
            body(i, k);
#else
    parallel_nd(rnn.n_gates, rnn.dic, body);
#endif
}

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::gemm)) {
    assert(ldA * ldB * ldC != 0);
    extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, &ldB,
            &beta, c_, &ldC, nullptr, pd()->rnn_.use_jit_gemm);
}

template <>
rnn_gemm_sig((ref_rnn_fwd_u8s8_t::gemm)) {
    assert(!"non packed gemm is disabled for int8");
}

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::packed_gemm)) {
#if (USE_MKL_PACKED_GEMM)
    assert(transA == 'N');
    cblas_sgemm_compute(CblasColMajor, CblasPacked,
            (transB == 'T') ? CblasTrans : CblasNoTrans, m, n, k, a_, ldA, b_,
            ldB, beta, c_, ldC);
#else
    UNUSED(transA);
    UNUSED(transB);
    UNUSED(m);
    UNUSED(n);
    UNUSED(k);
    UNUSED(alpha);
    UNUSED(ldA);
    UNUSED(b_);
    UNUSED(ldB);
    UNUSED(beta);
    UNUSED(c_);
    UNUSED(ldC);
    assert(!"packed gemm is disabled");
#endif
}

template <>
rnn_gemm_sig((ref_rnn_fwd_u8s8_t::packed_gemm)) {
#if (USE_MKL_PACKED_GEMM)
    int8_t offseta = 0, offsetb = 0;
    int32_t offsetc = 0;
    cblas_gemm_s8u8s32_compute(CblasColMajor, (CBLAS_TRANSPOSE)CblasPacked,
            CblasNoTrans, CblasFixOffset, m, n, k, alpha, a_, ldA, offseta, b_,
            ldB, offsetb, beta, c_, ldC, &offsetc);
#else
    UNUSED(transA);
    UNUSED(transB);
    UNUSED(m);
    UNUSED(n);
    UNUSED(k);
    UNUSED(alpha);
    UNUSED(ldA);
    UNUSED(b_);
    UNUSED(ldB);
    UNUSED(beta);
    UNUSED(c_);
    UNUSED(ldC);
    assert(!"packed gemm is disabled");
#endif
}

//*************** Grid computations strategy: linear ***************//
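/* Linear grid strategy: for each direction and each layer (walked in
 * natural or reverse order depending on the propagation kind), iterate over
 * the time steps and invoke the cell function. When rnn.merge_gemm_layer /
 * rnn.merge_gemm_iter are set, the per-time-step layer GEMMs (and, on
 * backward, the weight-gradient GEMMs) are fused into single GEMMs over all
 * rnn.n_iter time steps. */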
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_grid_execution_sig(
        (_ref_rnn_common_t<aprop, src_type, weights_type>::linear_execution)) {
    AOC<src_data_t, 4> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
    AOC<float, 4> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
    AOC<float, 5> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
            (rnn.n_states + 1), rnn.n_iter + 1,
            rnn.states_nld * rnn.states_ws_ld);
    AOC<acc_data_t, 4> ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir, rnn.n_iter,
            rnn.gates_nld * rnn.gates_ws_ld);
    AOC<weights_data_t *, 3> weights_input(
            weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer);
    AOC<weights_data_t *, 3> weights_states(
            weights_states_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter);
    AOC<float *, 3> bias(
            bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
    AOC<float, 3> diff_weights_layer(diff_weights_layer_, rnn.n_layer,
            rnn.n_dir,
            rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld);
    AOC<float, 3> diff_weights_iter(diff_weights_iter_, rnn.n_layer, rnn.n_dir,
            rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld);
    AOC<float, 3> diff_bias(
            diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
    AOC<float, 4> ws_grid(
            ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell);

    // We run the grid of computation
    for (int dir = 0; dir < rnn.n_dir; dir++) {
        for (int j = 0; j < rnn.n_layer; j++) {
            int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1;

            if ((aprop == prop_kind::forward) && rnn.merge_gemm_layer) {
                (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic,
                        rnn.mb * rnn.n_iter, rnn.slc, 1.0,
                        weights_input(lay, dir, 0), rnn.weights_iter_ld,
                        &(ws_states(lay, dir, 1, 0)), rnn.states_ws_ld, 0.0,
                        &(ws_gates(lay, dir, 0, 0)), rnn.gates_ws_ld);
            }

            for (int i = 0; i < rnn.n_iter; i++) {
                int iter = (aprop == prop_kind::forward) ? i : rnn.n_iter - i - 1;
                (this->*cell_func)(rnn,
                        &(ws_states(lay + 1, dir, iter + 1, 0)),
                        &(ws_c_states(lay + 1, dir, iter + 1, 0)),
                        &(ws_diff_states(lay, dir, 0, iter, 0)),
                        &(weights_input(lay, dir, 0)),
                        &(weights_states(lay, dir, 0)),
                        &(bias(lay, dir, 0)),
                        &(ws_states(lay, dir, iter + 1, 0)),
                        &(ws_states(lay + 1, dir, iter, 0)),
                        &(ws_c_states(lay + 1, dir, iter, 0)),
                        &(ws_diff_states(lay + 1, dir, 0, iter, 0)),
                        &(ws_diff_states(lay, dir, 0, iter + 1, 0)),
                        &(diff_weights_layer(lay, dir, 0)),
                        &(diff_weights_iter(lay, dir, 0)),
                        &(diff_bias(lay, dir, 0)),
                        &(ws_gates(lay, dir, iter, 0)),
                        &(ws_grid(lay, dir, iter, 0)),
                        ws_cell_);
            }

            if ((aprop == prop_kind::backward) && rnn.merge_gemm_layer) {
                (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb * rnn.n_iter,
                        rnn.n_gates * rnn.dic, 1.0, weights_input(lay, dir, 0),
                        rnn.weights_layer_ld,
                        (src_data_t *)(&(ws_gates(lay, dir, 0, 0))),
                        rnn.gates_ws_ld, 0.0,
                        (acc_data_t *)(&(ws_diff_states(
                                lay, dir, rnn.n_states, 0, 0))),
                        rnn.states_ws_ld);
                gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc,
                        rnn.mb * rnn.n_iter, 1.0,
                        (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
                        rnn.gates_ws_ld,
                        (src_data_t *)(&(ws_states(lay, dir, 1, 0))),
                        rnn.states_ws_ld, 1.0,
                        (acc_data_t *)(&(diff_weights_layer(lay, dir, 0))),
                        rnn.diff_weights_layer_ld);
            }
            if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) {
                gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic,
                        rnn.mb * rnn.n_iter, 1.0,
                        (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
                        rnn.gates_ws_ld,
                        (src_data_t *)(&(ws_states(lay + 1, dir, 0, 0))),
                        rnn.states_ws_ld, 1.0,
                        (acc_data_t *)(&(diff_weights_iter(lay, dir, 0))),
                        rnn.diff_weights_iter_ld);
            }
        }
    }
}

//********* GRID computations strategy: utility functions **********//

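/* copy_init_layer copies the user input sequence x(t) into the first layer
 * of the states workspace, once per direction: the left-to-right direction
 * is written at iteration index it + 1, the right-to-left direction is
 * mirrored in time. The backward-pass specialization below instead seeds
 * the diff-states workspace from diff_dst_layer, depending on how the two
 * directions are combined (bi_concat, bi_sum, l2r, r2l). */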
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_layer(
        const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
        float *__restrict ws_diff_states_, const src_data_t *__restrict xt_,
        const float *__restrict diff_dst_layer_) const {

    AOC<src_data_t, 4> ws_states(
            ws_states_, rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    auto xt_d = memory_desc_wrapper(pd()->src_pd(0));

    parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
        auto xxt = xt_ + xt_d.blk_off(it, b);
        src_data_t *ws_l2r_ptr = &(ws_states(0, it + 1, b, 0));
        src_data_t *ws_r2l_ptr = &(ws_states(rnn.n_dir - 1, rnn.n_iter - it, b, 0));
        if (rnn.exec_dir != r2l)
            for (int c = 0; c < rnn.slc; c++)
                ws_l2r_ptr[c] = xxt[c];
        if (rnn.exec_dir != l2r)
            for (int c = 0; c < rnn.slc; c++)
                ws_r2l_ptr[c] = xxt[c];
    });
}

template <>
void ref_rnn_bwd_f32_t::copy_init_layer(const rnn_conf_t &rnn,
        src_data_t *ws_states_, float *ws_diff_states_, const src_data_t *xt_,
        const float *diff_dst_layer_) const {
    AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
            (rnn.n_states + 1), rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    auto diff_dst_layer_d = memory_desc_wrapper(pd()->diff_dst_pd(0));

    switch (rnn.exec_dir) {
    case bi_concat:
        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
            auto diff_dst_layer_x
                    = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
            for (int s = 0; s < rnn.dic; s++) {
                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
                        = diff_dst_layer_x[s];
                ws_diff_states(
                        rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s)
                        = diff_dst_layer_x[rnn.dic + s];
            }
        });
        break;
    case bi_sum:
        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
            auto diff_dst_layer_x
                    = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
            for (int s = 0; s < rnn.dic; s++) {
                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
                        = diff_dst_layer_x[s];
                ws_diff_states(
                        rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s)
                        = diff_dst_layer_x[s];
            }
        });
        break;
    case l2r:
        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
            auto diff_dst_layer_x
                    = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
            for (int s = 0; s < rnn.dic; s++) {
                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
                        = diff_dst_layer_x[s];
            }
        });
        break;
    case r2l:
        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
            auto diff_dst_layer_x = diff_dst_layer_
                    + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b);
            for (int s = 0; s < rnn.dic; s++) {
                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
                        = diff_dst_layer_x[s];
            }
        });
        break;
    default: assert(!"Unsupported direction"); break;
    }
}

/* For the int8 configuration, input iteration states may be of type f32 or u8.
 * Internally, h_state is always stored in u8 and c_state is always stored in f32.
 * If the input states are of type u8, then h_state is copied and c_state is dequantized.
 * If the input states are of type f32, then h_state is quantized and c_state is copied.
 */
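/* A sketch of the quantization convention used below (taking the data
 * qparams from the primitive attributes):
 *     quantized   q = saturate<u8>(round(f * data_scale + data_shift))
 *     dequantized f = ((float)q - data_shift) / data_scale
 * maybe_q / maybe_deq apply these only when the source/destination data
 * types actually require a conversion. */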
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
template <typename input_data_t>
void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_iter(
        const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
        float *__restrict ws_c_states_, float *__restrict ws_diff_states_,
        const input_data_t *__restrict firstit_states_,
        const float *__restrict diff_dst_iter_) const {
    AOC<src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    AOC<float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
    float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
    round_mode_t rmode = pd()->attr()->round_mode_;

    const bool quantize
            = pd()->desc()->src_iter_desc.data_type == data_type::f32
            && rnn.dt_conf != all_f32;
    auto maybe_q = [&](input_data_t f) {
        if (quantize) {
            float qf = f * data_scale + data_shift;
            return qz_a1b0<float, src_data_t>()(qf, rmode);
        } else
            return (src_data_t)f;
    };

    const bool dequantize
            = pd()->desc()->src_iter_desc.data_type == data_type::u8;
    auto maybe_deq = [&](input_data_t s) {
        if (dequantize)
            return (((float)s - data_shift) / data_scale);
        else
            return (float)s;
    };
    auto firstit_states_d = memory_desc_wrapper(pd()->src_pd(1));
    if (firstit_states_) {
        parallel_nd(
                rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
                    for (int s = 0; s < rnn.sic; s++)
                        ws_states(lay + 1, dir, 0, b, s) = maybe_q(
                                firstit_states_[firstit_states_d.blk_off(
                                        lay, dir, 0, b, s)]);
                    if (pd()->cell_kind() == alg_kind::vanilla_lstm)
                        for (int s = 0; s < rnn.sic; s++)
                            ws_c_states(lay + 1, dir, 0, b, s) = maybe_deq(
                                    firstit_states_[firstit_states_d.blk_off(
                                            lay, dir, 1, b, s)]);
                });
    } else {
        parallel_nd(
                rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
                    for (int j = 0; j < rnn.sic; j++) {
                        ws_states(lay + 1, dir, 0, b, j) = (src_data_t)0;
                        ws_c_states(lay + 1, dir, 0, b, j) = 0.0f;
                    }
                });
    }
}

template <>
template <typename input_data_t>
void ref_rnn_bwd_f32_t::copy_init_iter(const rnn_conf_t &rnn,
        src_data_t *ws_states_, float *ws_c_states_, float *ws_diff_states_,
        const input_data_t *firstit_states_,
        const float *diff_dst_iter_) const {
    AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    auto diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_pd(1));
    if (diff_dst_iter_) {
        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
                [&](int lay, int dir, int state, int b) {
                    array_copy(&(ws_diff_states(
                                       lay, dir, state, rnn.n_iter, b, 0)),
                            diff_dst_iter_
                                    + diff_dst_iter_d.blk_off(
                                              lay, dir, state, b),
                            rnn.dic);
                });
    } else {
        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
                [&](int lay, int dir, int state, int i) {
                    for (int j = 0; j < rnn.dic; j++)
                        ws_diff_states(lay, dir, state, rnn.n_iter, i, j)
                                = 0.0f;
                });
    }
}

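/* copy_res_layer writes the last layer of the states workspace back to the
 * user dst_layer buffer, dequantizing if needed; for bidirectional
 * execution the two directions are either concatenated along the channel
 * dimension or summed (bi_sum). The backward specialization instead writes
 * diff_src_layer from the diff-states workspace, accumulating both
 * directions when present. */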
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
template <typename dst_data_t>
void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_layer(
        const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer,
        const src_data_t *ws_states_, const float *ws_diff_states_) const {

    auto dst_layer_d = memory_desc_wrapper(pd()->dst_pd(0));
    AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    float shift = (pd()->attr()->rnn_data_qparams_.shift_);
    float scale = (pd()->attr()->rnn_data_qparams_.scale_);

    const bool dequantize
            = pd()->desc()->dst_layer_desc.data_type == data_type::f32
            && rnn.dt_conf != all_f32;
    auto maybe_deq = [&](src_data_t s) {
        if (dequantize)
            return (dst_data_t)(((float)s - shift) / scale);
        else
            return (dst_data_t)s;
    };
    parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
        int dir = 0;
        if (rnn.exec_dir != r2l) {
            for (int s = 0; s < rnn.dic; s++) {
                dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
                        = maybe_deq(ws_states(rnn.n_layer, dir, it + 1, b, s));
            }
            dir = 1;
        }
        if (rnn.exec_dir != l2r) {
            for (int s = 0; s < rnn.dic; s++)
                switch (rnn.exec_dir) {
                case bi_sum:
                    dst_layer_[dst_layer_d.blk_off(it, b, s)]
                            += maybe_deq(ws_states(
                                    rnn.n_layer, dir, rnn.n_iter - it, b, s));
                    break;
                default:
                    dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
                            = maybe_deq(ws_states(
                                    rnn.n_layer, dir, rnn.n_iter - it, b, s));
                }
        }
    });
}

template <>
template <typename dst_data_t>
void ref_rnn_bwd_f32_t::copy_res_layer(
        const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer_,
        const src_data_t *ws_states_, const float *ws_diff_states_) const {
    auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_pd(0));
    AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
            rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
            rnn.states_ws_ld);

    parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
        int dir = 0;
        for (int s = 0; s < rnn.slc; s++) {
            float *dst_addr = diff_src_layer_
                    + diff_src_layer_d.blk_off(
                              (rnn.exec_dir == r2l) ? rnn.n_iter - 1 - it : it,
                              b, dir * rnn.slc + s);
            float res = ws_diff_states(0, 0, rnn.n_states, it, b, s);
            if (rnn.n_dir - 1)
                res += ws_diff_states(
                        0, 1, rnn.n_states, rnn.n_iter - 1 - it, b, s);
            dst_addr[0] = res;
        }
    });
}

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
template <typename output_data_t>
void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_iter(
        const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_,
        const src_data_t *ws_states_, float *ws_c_states_,
        const float *ws_diff_states_) const {
    auto dst_iter_d = memory_desc_wrapper(pd()->dst_pd(1));
    AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    AOC<const float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
    float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
    round_mode_t rmode = pd()->attr()->round_mode_;

    const bool quantize = pd()->desc()->dst_iter_desc.data_type == data_type::u8
            && rnn.dt_conf != all_f32;
    auto maybe_q = [&](float f) {
        if (quantize) {
            float qf = f * data_scale + data_shift;
            return qz_a1b0<float, output_data_t>()(qf, rmode);
        } else
            return (output_data_t)f;
    };

    const bool dequantize
            = pd()->desc()->dst_iter_desc.data_type == data_type::f32
            && rnn.dt_conf != all_f32;
    auto maybe_deq = [&](src_data_t s) {
        if (dequantize)
            return (output_data_t)(((float)s - data_shift) / data_scale);
        else
            return (output_data_t)s;
    };
    if (dst_iter_) {
        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb,
                [&](int lay, int dir, int b) {
                    for (int s = 0; s < rnn.dic; s++) {
                        dst_iter_[dst_iter_d.blk_off(lay, dir, 0, b, s)]
                                = maybe_deq(ws_states(
                                        lay + 1, dir, rnn.n_iter, b, s));
                    }
                    if (pd()->cell_kind() == alg_kind::vanilla_lstm)
                        for (int s = 0; s < rnn.dic; s++) {
                            dst_iter_[dst_iter_d.blk_off(lay, dir, 1, b, s)]
                                    = maybe_q(ws_c_states(
                                            lay + 1, dir, rnn.n_iter, b, s));
                        }
                });
    }
}

template <>
template <typename output_data_t>
void ref_rnn_bwd_f32_t::copy_res_iter(
        const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_,
        const src_data_t *ws_states_, float *ws_c_states_,
        const float *ws_diff_states_) const {
    auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_pd(1));
    AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
            rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
            rnn.states_ws_ld);
    if (diff_src_iter_) {
        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
                [&](int lay, int dir, int state, int b) {
                    for (int s = 0; s < rnn.sic; s++) {
                        diff_src_iter_[diff_src_iter_d.blk_off(
                                lay, dir, state, b, s)]
                                = ws_diff_states(lay, dir, state, 0, b, s);
                    }
                });
    }
}

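/* bias_prepare fills the per-(layer, direction, part) array of bias
 * pointers. When rnn.copy_bias is set, the user bias is first copied to the
 * scratchpad (so it can later be adjusted, e.g. by bias_finalize for the
 * int8 path) and the pointers reference the scratch copy; otherwise they
 * point directly into the user-provided buffer. */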
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_bias_prepare_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::bias_prepare)) {
    /* Original set of bias provided by the user */
    AOC<const float, 5> b(
            b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
    /* Array of pointers initialized in packing */
    AOC<float *, 3> bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
    AOC<float, 3> scratch_bias(
            scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);

    if (rnn.copy_bias) {
        parallel_nd(rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic,
                [&](size_t i) { scratch_bias_[i] = b_[i]; });
    }

    for (int i = 0; i < rnn.n_layer; i++) {
        for (int d = 0; d < rnn.n_dir; d++) {
            int offset_bias = 0;
            for (int p = 0; p < rnn.n_parts_bias; p++) {
                bias(i, d, p) = rnn.copy_bias
                        ? (float *) &scratch_bias(i, d, offset_bias)
                        : (float *) &b(i, d, offset_bias);
                offset_bias += rnn.parts_bias[p] * rnn.dic;
            }
        }
    }
}

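/* A sketch of the int8 bias compensation applied by bias_finalize, assuming
 * the per-output-channel weight sums were stored in w_layer_comp /
 * w_iter_comp when the weights were reordered:
 *
 *     bias[off] -= (comp_iter[off] + comp_layer[off]) * data_shift
 *                  / (weights_scale * data_scale)
 *
 * i.e. it subtracts the contribution of the data zero-point (shift) that
 * the u8 * s8 GEMMs would otherwise fold into the accumulators. */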
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_bias_finalize_sig(
        (_ref_rnn_common_t<aprop, src_type, weights_type>::bias_finalize)) {
    if (rnn.dt_conf != all_f32) {
        float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
        float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
        float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_;
        bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0;
        for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++)
            for (int j = 0; j < rnn.n_bias * rnn.dic; j++) {
                size_t off = i * rnn.n_bias * rnn.dic + j;
                float weights_scale
                        = scale_per_oc ? weights_scales[j] : weights_scales[0];
                scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off])
                        * data_shift / (weights_scale * data_scale);
            }
    }
}

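/* The two weights-assignment strategies below build the same array of
 * per-(layer, direction, part) weight pointers: assign_packed_weights walks
 * a buffer of GEMM-packed blobs using the per-part pack sizes, while
 * assign_weights points directly into the user weights, advancing the
 * offset by gates_per_part[p] * OC_size elements (times the leading
 * dimension for non-ldigo formats). */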
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_weights_assign_sig((_ref_rnn_common_t<aprop, src_type,
        weights_type>::assign_packed_weights)) {
    AOC<weights_data_t *, 3> weights(weights_, rnn.n_layer, rnn.n_dir, n_parts);

    size_t offset_packed = 0;
    for (int l = 0; l < rnn.n_layer; l++)
        for (int d = 0; d < rnn.n_dir; d++) {
            for (int p = 0; p < n_parts; p++) {
                weights(l, d, p) = (weights_data_t *)&w_[offset_packed];
                offset_packed
                        += part_weights_pack_size[p] / sizeof(weights_data_t);
            }
        }
}

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_weights_assign_sig(
        (_ref_rnn_common_t<aprop, src_type, weights_type>::assign_weights)) {
    assert(nld * ld != 0);
    /* Original set of weights provided by the user */
    AOC<const weights_data_t, 3> w(w_, rnn.n_layer, rnn.n_dir, nld * ld);
    /* Array of pointers for each part of weights */
    AOC<weights_data_t *, 3> weights(weights_, rnn.n_layer, rnn.n_dir, n_parts);

    for (int i = 0; i < rnn.n_layer; i++)
        for (int d = 0; d < rnn.n_dir; d++) {
            size_t offset_weights = 0;
            for (int p = 0; p < n_parts; p++) {
                weights(i, d, p) = (weights_data_t *)&w(i, d, offset_weights);
                offset_weights += fmt == memory_format::ldigo ?
                        gates_per_part[p] * OC_size :
                        gates_per_part[p] * OC_size * ld;
            }
        }
}

//********************* Execution function *********************//
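/* execute_() drives one primitive execution:
 *   1. fetch the user inputs/outputs and the workspace or scratchpad buffers,
 *   2. prepare the bias pointers and assign (packed or plain) weight pointers,
 *   3. finalize the bias (int8 compensation),
 *   4. copy the input sequence and initial states into the workspace,
 *   5. run the grid computation (linear_execution),
 *   6. copy the last-layer / last-iteration results back to user memory,
 *      dispatching on the data-type configuration (f32 vs u8 variants). */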
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
void _ref_rnn_common_t<aprop, src_type, weights_type>::execute_() const {
    const rnn_conf_t &rnn = this->pd()->rnn_;
    int input_idx = 0;
    int output_idx = 0;
    auto input = reinterpret_cast<const src_data_t *>(
            this->input_memory(input_idx++));
    auto states = pd()->with_src_iter() ? (this->input_memory(input_idx++)) :
                                          nullptr;

    const char *layer_weights_n_comp = this->input_memory(input_idx++);
    auto w_layer
            = reinterpret_cast<const weights_data_t *>(layer_weights_n_comp);
    auto w_layer_comp = reinterpret_cast<const float *>(layer_weights_n_comp
            + rnn.weights_layer_comp_offset);
    const char *iter_weights_n_comp = this->input_memory(input_idx++);
    auto w_iter
            = reinterpret_cast<const weights_data_t *>(iter_weights_n_comp);
    auto w_iter_comp = reinterpret_cast<const float *>(iter_weights_n_comp
            + rnn.weights_iter_comp_offset);
    auto bias = pd()->with_bias() ?
            reinterpret_cast<const float *>(this->input_memory(input_idx++)) :
            nullptr;

    auto dst_last_layer = rnn.is_fwd ? this->memory(output_idx++) :
                                       this->input_memory(input_idx++);
    auto dst_last_iter = pd()->with_dst_iter()
            ? (rnn.is_fwd
                ? this->memory(output_idx++)
                : this->input_memory(input_idx++))
            : nullptr;

    auto diff_dst_layer = rnn.is_fwd ?
            nullptr :
            reinterpret_cast<const float *>(this->input_memory(input_idx++));
    auto diff_dst_iter = rnn.is_fwd || !pd()->with_dst_iter() ?
            nullptr :
            reinterpret_cast<const float *>(this->input_memory(input_idx++));

    auto scratchpad = this->scratchpad();

    auto ptr_wei_layer
            = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_layer);
    auto ptr_wei_iter
            = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_iter);
    auto ptr_bias =
        scratchpad.template get<float *>(key_rnn_ptrs_bia);

    // fetching buffers from the workspace
    // if no workspace was provided we use the scratchpad
    char *scratch_ptr = scratchpad.template get<char>(key_rnn_space);
    char *ws_ptr = nullptr;
    if (rnn.use_workspace)
        ws_ptr = rnn.is_fwd
            ? this->memory(output_idx++)
            : const_cast<char *>(this->input_memory(input_idx++));
    char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr;
    acc_data_t *ws_gates = (acc_data_t *)(base_ptr + ws_gates_offset_);
    src_data_t *ws_states = (src_data_t *)(base_ptr + ws_states_offset_);
    float *ws_c_states = (float *)(base_ptr + ws_c_states_offset_);
    float *ws_diff_states = (float *)(base_ptr + ws_diff_states_offset_);
    float *ws_grid = (float *)(base_ptr + ws_grid_comp_offset_);
    float *ws_cell = (float *)(base_ptr + ws_cell_comp_offset_);

    auto diff_src_layer = rnn.is_fwd ?
            nullptr :
            reinterpret_cast<float *>(this->memory(output_idx++));
    auto diff_src_iter = rnn.is_fwd || !pd()->with_src_iter() ?
            nullptr :
            reinterpret_cast<float *>(this->memory(output_idx++));
    auto diff_weights_layer = rnn.is_fwd ?
            nullptr :
            reinterpret_cast<float *>(this->memory(output_idx++));
    auto diff_weights_iter = rnn.is_fwd ?
            nullptr :
            reinterpret_cast<float *>(this->memory(output_idx++));
    auto diff_bias = rnn.is_fwd || !pd()->with_bias() ?
            nullptr :
            reinterpret_cast<float *>(this->memory(output_idx++));

    // Fetching extra buffers from scratchpad
    float *ws_bias = (float *)(scratch_ptr + ws_bias_offset_);

    // initialize diff_states to 0
    if (aprop == prop_kind::backward)
        array_set(ws_diff_states, 0.0f, rnn.ws_diff_states_size / sizeof(float));

    /* Pack (if using the packed gemm API) or copy (if input arrays have a bad
     * leading dimension) the weights, and set up the bias pointers */
    (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias);

    (this->*weights_iter_assign_func)(rnn, rnn.weights_iter_fmt,
            rnn.weights_iter_nld, rnn.weights_iter_ld, rnn.dic,
            rnn.sic, rnn.n_parts_weights_iter, rnn.parts_weights_iter,
            rnn.part_weights_iter_pack_size, ptr_wei_iter, w_iter,
            ptr_bias, bias, ws_bias);
    (this->*weights_layer_assign_func)(rnn, rnn.weights_layer_fmt,
            rnn.weights_layer_nld, rnn.weights_layer_ld, rnn.dic, rnn.slc,
            rnn.n_parts_weights_layer, rnn.parts_weights_layer,
            rnn.part_weights_layer_pack_size, ptr_wei_layer, w_layer, ptr_bias,
            bias, ws_bias);

    (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp);

    // we first need to copy the initial states and input into ws
    copy_init_layer(rnn, ws_states, ws_diff_states, input, diff_dst_layer);
    if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
            || rnn.dt_conf == all_f32)
        copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
                (const float *)states, diff_dst_iter);
    else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
        copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
                (const uint8_t *)states, diff_dst_iter);
    else
        assert(!"unimplemented");

    // run the execution on the grid
    (this->*grid_computation)(rnn, ptr_wei_layer, ptr_wei_iter, ptr_bias,
            ws_states, ws_c_states, ws_diff_states, ws_gates, ws_cell, ws_grid,
            diff_weights_layer, diff_weights_iter, diff_bias);

    // Finally we copy the results to the result buffers
    if (rnn.dt_conf == u8u8u8f32 || rnn.dt_conf == f32u8f32f32
            || rnn.dt_conf == all_f32)
        copy_res_layer(rnn, (float *)dst_last_layer, diff_src_layer, ws_states,
                ws_diff_states);
    else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == f32u8f32u8)
        copy_res_layer(rnn, (uint8_t *)dst_last_layer, diff_src_layer,
                ws_states, ws_diff_states);
    else
        assert(!"unimplemented");

    if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
            || rnn.dt_conf == all_f32)
        copy_res_iter(rnn, (float *)dst_last_iter, diff_src_iter, ws_states,
                ws_c_states, ws_diff_states);
    else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
        copy_res_iter(rnn, (uint8_t *)dst_last_iter, diff_src_iter, ws_states,
                ws_c_states, ws_diff_states);
    else
        assert(!"unimplemented");
}

/* Fix for MSVS warning C4661 */
template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution);
template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution);
template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution);
template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru);
template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru);
template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru);
template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr);
template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr);
template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr);
template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise);
template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise);
template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise);
template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise);
template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise);
template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise);
template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise);
template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise);
template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise);

template struct _ref_rnn_common_t<prop_kind::forward, data_type::f32, data_type::f32>;
template struct _ref_rnn_common_t<prop_kind::forward, data_type::u8, data_type::s8>;
template struct _ref_rnn_common_t<prop_kind::backward, data_type::f32, data_type::f32>;

#undef AOC
}
}
}