Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / rnn / ref_rnn.cpp
1 /*******************************************************************************
2  * Copyright 2018 Intel Corporation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *******************************************************************************/
16
17 #include <stdlib.h>
18
19 #include "src/common/mkldnn_thread.hpp"
20
21 #include "rnn/rnn.hpp"
22 #include "rnn/rnn_aux.hpp"
23
24 namespace rnn {
25
26 #define min(a, b) ((a < b) ? a : b)
27 #define max(a, b) ((a > b) ? a : b)
28 #define xstr(a) str(a)
29 #define str(a) #a
30
31 #define AOC array_offset_calculator
32
33 void lstm_activation(int dic, int n_gates, int batch,
34         //    float a[batch][n_gates * wc]
35         float *a) {
36     AOC<float> pa(a, batch, n_gates, dic);
37     mkldnn::impl::parallel_nd(batch, [&](int ib) {
38         for (int ih = 0; ih < dic; ih++) {
39             pa(ib, 0, ih) = logistic(pa(ib, 0, ih));
40             pa(ib, 1, ih) = logistic(pa(ib, 1, ih));
41             pa(ib, 2, ih) = tanhf(pa(ib, 2, ih));
42             pa(ib, 3, ih) = logistic(pa(ib, 3, ih));
43             for (int ig = 0; ig < 4; ig++) {
44                 print(80, "activation 1 a[%d][%d][%d] = %.7f\n", ib, ig, ih,
45                         pa(ib, ig, ih));
46             }
47         }
48     });
49 }
50
51 float activation(activation_t f, float x, bool is_fwd = true) {
52     float result = 0;
53     switch (f) {
54     case RELU: result = is_fwd ? relu(x) : drelu(x); break;
55     case LOGISTIC: result = is_fwd ? logistic(x) : x_m_square(x); break;
56     case TANH: result = is_fwd ? tanhf(x) : one_m_square(x); break;
57     default: assert(!"unknown activation");
58     }
59     return result;
60 }
61
62 void rnn_fwd(activation_t f, int sic, int slc, int dic, int wc, int batch,
63         int n_gates, float *dst_iter_h_, float *gates_,
64         const float *weights_layer_, const float *weights_iter_h_,
65         const float *bias_, const float *src_layer_, const float *src_iter_h_) {
66     AOC<float> dst_iter_h(dst_iter_h_, batch, n_gates, wc);
67     AOC<const float> bias(bias_, n_gates, dic);
68     AOC<float> gates(gates_, batch, n_gates, dic);
69
70     gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc,
71             weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic);
72     gemm("C", "N", "N", batch, n_gates * dic, sic, 1.0, src_iter_h_, wc,
73             weights_iter_h_, n_gates * dic, 1.0, gates_, n_gates * dic);
74
75     for (int i = 0; i < batch; i++)
76         for (int j = 0; j < n_gates; j++)
77             for (int k = 0; k < dic; k++) {
78                 const auto tmp = activation(f, gates(i, j, k) + bias(j, k));
79                 gates(i, j, k) = tmp;
80                 dst_iter_h(i, j, k) = tmp;
81             }
82 }
83
84 void gru_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates,
85         float *dst_iter_h_, float *gates_, const float *weights_layer_,
86         const float *weights_iter_h_, const float *bias_,
87         const float *src_layer_, const float *src_iter_h_) {
88     AOC<const float> src_iter_h(src_iter_h_, batch, wc);
89     AOC<const float> weights_layer(weights_layer_, slc, n_gates, dic);
90     AOC<const float> weights_iter_h(weights_iter_h_, sic, n_gates, dic);
91     AOC<const float> bias(bias_, n_gates, dic);
92     AOC<float> gates(gates_, batch, n_gates, dic);
93     AOC<float> h_dst(dst_iter_h_, batch, wc);
94
95     gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc,
96             weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic);
97     gemm("C", "N", "N", batch, (n_gates - 1) * dic, sic, 1.0, src_iter_h_,
98             wc, weights_iter_h_, n_gates * dic, 1.0, gates_, n_gates * dic);
99     for (int i = 0; i < batch; i++)
100         for (int j = 0; j < n_gates - 1; j++)
101             for (int k = 0; k < dic; k++) {
102                 gates(i, j, k) = logistic(gates(i, j, k) + bias(j, k));
103             }
104
105     for (int i = 0; i < batch; i++)
106         for (int k = 0; k < dic; k++) {
107             h_dst(i, k) = src_iter_h(i, k) * gates(i, 1, k);
108         }
109
110     gemm("C", "N", "N", batch, dic, sic, 1.0, dst_iter_h_, wc,
111             &(weights_iter_h(0, 2, 0)), n_gates * dic, 1.0, &(gates(0, 2, 0)),
112             n_gates * dic);
113
114     for (int i = 0; i < batch; i++)
115         for (int k = 0; k < dic; k++) {
116             gates(i, 2, k) = tanhf(gates(i, 2, k) + bias(2, k));
117         }
118
119     for (int i = 0; i < batch; i++)
120         for (int k = 0; k < dic; k++) {
121             h_dst(i, k) = gates(i, 0, k) * src_iter_h(i, k) +
122                 (1 - gates(i, 0, k)) * gates(i, 2, k);
123         }
124 }
125
126 void gru_lbr_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates,
127         float *dst_iter_h_, float *gates_, const float *weights_layer_,
128         const float *weights_iter_h_, const float *bias_,
129         const float *src_layer_, const float *src_iter_h_,
130         float *ws_local_) {
131     AOC<const float> src_iter_h(src_iter_h_, batch, wc);
132     AOC<const float> weights_layer(weights_layer_, slc, n_gates, dic);
133     AOC<const float> weights_iter_h(weights_iter_h_, sic, n_gates, dic);
134     AOC<const float> bias(bias_, n_gates + 1, dic);
135     AOC<float> gates(gates_, batch, n_gates, dic);
136     AOC<float> h_dst(dst_iter_h_, batch, wc);
137     AOC<float> tmp_ws(ws_local_, batch, n_gates, dic);
138
139     gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0,  src_layer_, wc,
140             weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic);
141
142     gemm("C", "N", "N", batch, n_gates * dic, sic, 1.0, src_iter_h_, wc,
143             weights_iter_h_, n_gates * dic, 0.0, ws_local_, n_gates * dic);
144
145     for (int i = 0; i < batch; i++)
146         for (int j = 0; j < n_gates - 1; j++)
147             for (int k = 0; k < dic; k++) {
148                 gates(i, j, k) = logistic(gates(i, j, k) + tmp_ws(i, j, k)
149                     + bias(j, k));
150             }
151
152     for (int i = 0; i < batch; i++)
153         for (int k = 0; k < dic; k++) {
154             gates(i, 2, k) = tanhf(gates(i, 2, k) + gates(i, 1, k) * (tmp_ws(i, 2, k)
155                 + bias(3, k)) + bias(2, k));
156         }
157
158     for (int i = 0; i < batch; i++)
159         for (int k = 0; k < dic; k++) {
160             h_dst(i, k) = gates(i, 0, k) * src_iter_h(i, k) +
161                 (1 - gates(i, 0, k)) * gates(i, 2, k);
162         }
163
164 }
165
166 // w = [weights_layer | weights_iter] : with order f, i , o, \bar(c)
167 void lstm_fwd(const rnn_prb_t *p, int sic, int slc, int dic, int wc, int batch,
168         int n_gates, float *dst_iter_h_, float *c_dst_, float *gates_,
169         const float *weights_layer_, const float *weights_iter_h_,
170         const float *bias_, const float *src_layer_, const float *src_iter_h_,
171         const float *src_iter_c_) {
172     AOC<float> h_dst(dst_iter_h_, batch, wc);
173     AOC<float> c_dst(c_dst_, batch, wc);
174     AOC<const float> bias(bias_, n_gates, dic);
175     AOC<const float> src_iter_c(src_iter_c_, batch, wc);
176     AOC<float> gates(gates_, batch, n_gates, dic);
177
178     const int ohi = 0;
179     const int ohf = 1;
180     const int ohc = 2;
181     const int oho = 3;
182
183     gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc,
184             weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic);
185     gemm("C", "N", "N", batch, n_gates * dic, sic, 1.0, src_iter_h_, wc,
186             weights_iter_h_, n_gates * dic, 1.0, gates_, n_gates * dic);
187
188     auto maybe_deq_w = [&](float g, int oc) {
189         if (p->cfg == conf_f32)
190             return g;
191         float scale = 1.;
192         if (p->scale_policy == PER_OC)
193             scale = p->wei_oc_scales[oc];
194         else if (p->scale_policy == COMMON)
195             scale = p->wei_scale;
196         scale *= p->data_scale;
197         return g / scale;
198     };
199
200     // add bias
201     for (int i = 0; i < batch; i++)
202         for (int j = 0; j < n_gates; j++)
203             for (int k = 0; k < dic; k++) {
204                 gates(i, j, k)
205                         = maybe_deq_w(gates(i, j, k), j * dic + k) + bias(j, k);
206             }
207
208     // run the eltwise
209     lstm_activation(dic, n_gates, batch, gates_);
210
211     auto maybe_q_d = [&](float h) {
212         if (p->cfg == conf_f32)
213             return h;
214         float fp = p->data_scale * h;
215         using R = attr_t::round_mode_t;
216         switch (p->attr.irmode) {
217         case R::DOWN: fp = floorf(fp); break;
218         case R::NEAREST: fp = nearbyintf(fp); break;
219         default: assert(!"unkown round mode");
220         }
221         if (fp + p->data_shift > p->cfg[input].max)
222             fp = p->cfg[input].max - p->data_shift;
223         if (fp + p->data_shift < p->cfg[input].min)
224             fp = p->cfg[input].min - p->data_shift;
225         return fp;
226     };
227
228     // compute C_t_l and H_t_l
229     for (int i = 0; i < batch; i++)
230         for (int j = 0; j < dic; j++) {
231             float tmp = gates(i, ohf, j) * src_iter_c(i, j)
232                     + gates(i, ohi, j) * gates(i, ohc, j);
233             c_dst(i, j) = tmp;
234             h_dst(i, j) = maybe_q_d(gates(i, oho, j) * tanhf(tmp));
235         }
236 }
237
238 void rnn_cell_fwd(const rnn_prb_t *p, alg_t alg, activation_t f, int sic,
239         int slc, int dic, int wc, int batch, int n_gates, float *dst_iter_h,
240         float *dst_iter_c, float *gates, const float *weights_layer,
241         const float *weights_iter, const float *bias, const float *src_layer,
242         const float *src_iter_h, const float *src_iter_c, float *ws_local_) {
243     switch (alg) {
244     case VANILLA_GRU:
245         gru_fwd(sic, slc, dic, wc, batch, n_gates, dst_iter_h, gates,
246                 weights_layer, weights_iter, bias, src_layer, src_iter_h);
247         break;
248     case LBR_GRU:
249         gru_lbr_fwd(sic, slc, dic, wc, batch, n_gates, dst_iter_h, gates,
250                 weights_layer, weights_iter, bias, src_layer, src_iter_h,
251                 ws_local_);
252         break;
253     case VANILLA_LSTM:
254         lstm_fwd(p, sic, slc, dic, wc, batch, n_gates, dst_iter_h, dst_iter_c,
255                 gates, weights_layer, weights_iter, bias, src_layer, src_iter_h,
256                 src_iter_c);
257         break;
258     case VANILLA_RNN:
259         rnn_fwd(f, sic, slc, dic, wc, batch, n_gates, dst_iter_h, gates,
260                 weights_layer, weights_iter, bias, src_layer, src_iter_h);
261         break;
262     default: break;
263     }
264 }
265
266 void copy(int dimc, int dimr, int ld_src, int ld_dst, const float *src_,
267         float *dst_, rnn_action_t action = action_copy) {
268     AOC<const float> src(src_, dimc, ld_src);
269     AOC<float> dst(dst_, dimc, ld_dst);
270
271     mkldnn::impl::parallel_nd(dimc, [&](int i) {
272         for (int j = 0; j < dimr; j++) {
273             dst(i, j) = action == action_sum
274                     ? dst(i, j) + src(i, j) : src(i, j);
275         }
276     });
277 }
278
279 void shift(int dimc, int dimr, int ld_src, float *src_, float shift,
280         bool round = false, const rnn_prb_t *p = nullptr) {
281     AOC<float> src(src_, dimc, ld_src);
282     mkldnn::impl::parallel_nd(dimc, [&](int i) {
283         for (int j = 0; j < dimr; j++) {
284             float fp = src(i, j) + shift;
285             if (round) {
286                 using R = attr_t::round_mode_t;
287                 switch (p->attr.irmode) {
288                 case R::DOWN: fp = floorf(fp); break;
289                 case R::NEAREST: fp = nearbyintf(fp); break;
290                 default: assert(!"unkown round mode");
291                 }
292                 if (fp > UINT8_MAX)
293                     fp = UINT8_MAX;
294                 if (fp < 0)
295                     fp = 0;
296             }
297             src(i, j) = fp;
298         }
299     });
300 }
301
302 void scale(int dimc, int dimr, int ld_src, float *src_, float scale,
303         bool round = false, const rnn_prb_t *p = nullptr) {
304     AOC<float> src(src_, dimc, ld_src);
305     mkldnn::impl::parallel_nd(dimc, [&](int i) {
306         for (int j = 0; j < dimr; j++) {
307             float fp = src(i, j) * scale;
308             if (round) {
309                 using R = attr_t::round_mode_t;
310                 switch (p->attr.irmode) {
311                 case R::DOWN: fp = floorf(fp); break;
312                 case R::NEAREST: fp = nearbyintf(fp); break;
313                 default: assert(!"unkown round mode");
314                 }
315             }
316             src(i, j) = fp;
317         }
318     });
319 }
320
321 /* lstm example:
322  * fwd: ws keeps {h, c} for every cell
323  */
324 void copy_init_fwd(const rnn_prb_t *p, alg_t alg, int sic, int slc, int dic,
325         int dlc, int wc, int batch, int n_layer, int n_iter, int n_dir,
326         int n_states, float *ws_, const float *src_layer_,
327         const float *firstit_states_, rnn_iter_direction_t iter_dir,
328         rnn_layer_direction_t lay_dir, int dir_val) {
329     AOC<float> ws(ws_, n_layer + 2, n_dir, n_iter + 2, n_states, batch * wc);
330     AOC<const float> src_layer(src_layer_, n_iter, batch * slc);
331     AOC<const float> firstit_states(
332             firstit_states_, n_layer, n_dir, n_states, batch * sic);
333
334     int lay_dest = (lay_dir == bottom2top) ? 0 : n_layer + 1;
335     int it_dest = (iter_dir == left2right) ? 0 : n_iter + 1;
336     bool is_int8 = p->cfg[input].dt == mkldnn_u8;
337
338     // Copy input
339     for (int it = 0; it < n_iter; it++) {
340         copy(batch, slc, slc, wc, &src_layer(it, 0),
341                 &ws(lay_dest, dir_val, it + 1, H, 0));
342         if (p->cfg[input].dt == mkldnn_u8)
343             // shift u8 input to s8 to avoid compensation in gemm
344             shift(batch, slc, wc, &ws(lay_dest, dir_val, it + 1, H, 0),
345                     -1. * p->data_shift);
346     }
347
348     // Copy states
349     for (int lay = 0; lay < n_layer; lay++) {
350         copy(batch, sic, sic, wc, &firstit_states(lay, dir_val, H, 0),
351                 &ws(lay + 1, dir_val, it_dest, H, 0));
352         if (p->cfg[states].dt == mkldnn_u8)
353             shift(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, H, 0),
354                     -1. * p->data_shift);
355         else if (p->cfg[states].dt == mkldnn_f32 && is_int8) {
356             // quantize to s8
357             scale(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, H, 0),
358                     p->data_scale, true, p);
359         }
360
361         if (alg == VANILLA_LSTM) {
362             copy(batch, sic, sic, wc, &firstit_states(lay, dir_val, C, 0),
363                     &ws(lay + 1, dir_val, it_dest, C, 0));
364             if (p->cfg[states].dt == mkldnn_u8) {
365                 // dequantize to f32
366                 shift(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, C, 0),
367                         -1. * p->data_shift);
368                 scale(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, C, 0),
369                         1. / p->data_scale);
370             }
371         }
372     }
373 }
374
375 /* lstm example:
376  * bwd: wsb keeps {dh, dc, dx} for every cell
377 */
378 void copy_init_bwd(alg_t alg, int sic, int slc, int dic, int dlc, int wc,
379         int batch, int n_layer, int n_iter, int n_dir, int n_states, float *ws_,
380         const float *src_layer_, const float *firstit_states_,
381         rnn_iter_direction_t iter_dir, rnn_layer_direction_t lay_dir,
382         int dir_val, bool is_concat = false) {
383     AOC<float> ws(
384             ws_, n_layer + 2, n_dir, n_iter + 2, n_states + 1, batch * wc);
385     auto c_stride = is_concat ? 2 * dlc : dlc;
386     AOC<const float> src_layer(src_layer_, n_iter, batch * c_stride);
387     AOC<const float> firstit_states(
388             firstit_states_, n_layer, n_dir, n_states, batch * dic);
389
390     int lay_dest = (lay_dir == bottom2top) ? 0 : n_layer + 1;
391     int it_dest = (iter_dir == left2right) ? 0 : n_iter + 1;
392
393     for (int it = 0; it < n_iter; it++)
394         copy(batch, dic, c_stride, wc,
395                 &src_layer(it, dir_val * is_concat * dlc),
396                 &ws(lay_dest, dir_val, it + 1, n_states, 0));
397
398     for (int lay = 0; lay < n_layer; lay++) {
399         copy(batch, dic, dic, wc, &firstit_states(lay, dir_val, H, 0),
400                 &ws(lay + 1, dir_val, it_dest, H, 0));
401         if (alg == VANILLA_LSTM) {
402             copy(batch, dic, dic, wc, &firstit_states(lay, dir_val, C, 0),
403                     &ws(lay + 1, dir_val, it_dest, C, 0));
404         }
405     }
406 }
407
408 void copy_res_fwd(const rnn_prb_t *p, alg_t alg, int sic, int slc, int dic,
409         int dlc, int wc, int batch, int n_layer, int n_iter, int n_dir,
410         int n_states, float *lastit_states_, float *lastlay_states_,
411         const float *ws_, rnn_iter_direction_t iter_dir,
412         rnn_layer_direction_t lay_dir, int dir_val, rnn_action_t action,
413         bool is_concat = false) {
414     int lastlay_c = is_concat ? 2 * dlc : dlc;
415     AOC<float> lastit_states(
416             lastit_states_, n_layer, n_dir, n_states, batch, dic);
417     AOC<float> lastlay_states(lastlay_states_, n_iter, batch, lastlay_c);
418     AOC<const float> ws(
419             ws_, n_layer + 2, n_dir, n_iter + 2, n_states, batch, wc);
420
421     // Copy states layer
422     for (int it = 0; it < n_iter; it++) {
423         for (int nb = 0; nb < batch; nb++) {
424             auto from = &ws(n_layer, dir_val, it + 1, H, nb, 0);
425             auto to = &lastlay_states(
426                     it, nb, action == action_concat ? dlc : 0);
427             copy(1, dlc, wc, lastlay_c, from, to, action);
428
429             if (p->cfg[dst_last_layer].dt == mkldnn_u8) {
430                 // shift s8 internal ws to u8
431                 shift(1, dlc, lastlay_c, to, p->data_shift);
432             } else {
433                 // dequantize to f32
434                 scale(1, dlc, lastlay_c, to, 1. / p->data_scale);
435             }
436         }
437     }
438
439     int it_source = (iter_dir == left2right) ? n_iter : 1;
440
441     // Copy states iteration
442     for (int lay = 0; lay < n_layer; lay++) {
443         if (alg == VANILLA_LSTM) {
444             copy(batch, dic, wc, dic, &ws(lay + 1, dir_val, it_source, C, 0, 0),
445                     &lastit_states(lay, dir_val, C, 0, 0));
446             if (p->cfg[dst_last_iteration].dt == mkldnn_u8) {
447                 // quantize internal f32 ws to u8
448                 scale(batch, dic, dic, &lastit_states(lay, dir_val, C, 0, 0),
449                         p->data_scale);
450                 shift(batch, dic, dic, &lastit_states(lay, dir_val, C, 0, 0),
451                         p->data_shift, true, p);
452             }
453         }
454         copy(batch, dic, wc, dic, &ws(lay + 1, dir_val, it_source, H, 0, 0),
455                 &lastit_states(lay, dir_val, H, 0, 0));
456         if (p->cfg[dst_last_iteration].dt == mkldnn_u8) {
457             // shift s8 internal ws to u8
458             shift(batch, dic, dic, &lastit_states(lay, dir_val, H, 0, 0),
459                     p->data_shift);
460         } else {
461             // dequantize to f32
462             scale(batch, dic, dic, &lastit_states(lay, dir_val, H, 0, 0),
463                     1. / p->data_scale);
464         }
465     }
466 }
467
468 void copy_res_bwd(alg_t alg, int sic, int slc, int dic, int dlc, int wc,
469         int batch, int n_layer, int n_iter, int n_dir, int n_states,
470         float *lastit_states_, float *lastlay_states_, const float *ws_,
471         rnn_iter_direction_t iter_dir, rnn_layer_direction_t lay_dir,
472         int dir_val, rnn_action_t action) {
473     AOC<float> lastit_states(
474             lastit_states_, n_layer, n_dir, n_states, batch, sic);
475     AOC<float> lastlay_states(lastlay_states_, n_iter, batch, slc);
476     AOC<const float> ws(
477             ws_, n_layer + 2, n_dir, n_iter + 2, n_states + 1, batch, wc);
478     for (int it = 0; it < n_iter; it++) {
479         for (int nb = 0; nb < batch; nb++) {
480             // copy H to last layer states
481             auto from = &ws(1, dir_val, it + 1, n_states, nb, 0);
482             auto to = &lastlay_states(it, nb, 0);
483
484             copy(1, slc, wc, slc, from, to, action);
485         }
486     }
487
488     int it_source = (iter_dir == left2right) ? n_iter : 1;
489
490     for (int lay = 0; lay < n_layer; lay++) {
491         if (alg == VANILLA_LSTM) {
492             copy(batch, sic, wc, sic, &ws(lay + 1, dir_val, it_source, C, 0, 0),
493                     &lastit_states(lay, dir_val, C, 0, 0));
494         }
495         copy(batch, sic, wc, sic, &ws(lay + 1, dir_val, it_source, H, 0, 0),
496                 &lastit_states(lay, dir_val, H, 0, 0));
497     }
498 }
499
500 void rnn_linear_fwd(const rnn_prb_t *p, mkldnn_rnn_direction_t direction,
501         const float *src_iter_, const float *src_layer_,
502         const float *weights_layer_, const float *weights_iter_h_,
503         const float *bias_, float *dst_iter_, float *dst_layer_, float *ws_,
504         float *gates_) {
505
506     const alg_t alg = p->alg;
507     const int sic = p->sic;
508     const int slc = p->slc;
509     const int dic = p->dic;
510     const int dlc = p->dlc;
511     const int wc = max(sic, max(slc, dic));
512     bool is_lbr = p->alg == LBR_GRU;
513     bool is_concat = direction == mkldnn_bidirectional_concat;
514
515     const int batch = p->mb;
516     const int n_gates = p->n_gates();
517     const int n_states = p->n_states();
518     const int n_layer = p->n_layer;
519     const int n_iter = p->n_iter;
520     const int n_dir = p->n_directions();
521     activation_t f = p->activation;
522
523     AOC<const float> bias(bias_, n_layer, n_dir, (n_gates + is_lbr) * dic);
524     AOC<const float> weights_layer(
525             weights_layer_, n_layer, n_dir, n_gates * dic, slc);
526     AOC<const float> weights_iter(
527             weights_iter_h_, n_layer, n_dir, n_gates * dic, sic);
528     AOC<float> ws(ws_, n_layer + 2, n_dir, n_iter + 2, n_states, batch, wc);
529     AOC<float> gates(gates_, n_layer, n_dir, n_iter, batch, n_gates, dic);
530
531     int ws_local_size = is_lbr * batch * n_gates * dic;
532     float *ws_local_ = new float[ws_local_size];
533
534     auto process_direction = [&](rnn_iter_direction_t iter_dir,
535             rnn_layer_direction_t lay_dir, int dir_val, rnn_action_t action) {
536         // we first need to copy the initial states and input into ws
537         // it simplifies the logic in the following code
538         print(80, "rnn_linear_fwd: call copy_init dir_val = %d\n", dir_val);
539         copy_init_fwd(p, alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter,
540                 n_dir, n_states, ws_, src_layer_, src_iter_, iter_dir, lay_dir,
541                 dir_val);
542
543         // We run the grid of computation
544         for (int il = 0; il < n_layer; il++) {
545             for (int it = 0; it < n_iter; it++) {
546                 print(80, "==== layer = %d iter = %d ===\n", il, it);
547                 int iter = (iter_dir == left2right) ? it + 1 : n_iter - it;
548                 int prev_iter = (iter_dir == left2right) ? iter - 1 : iter + 1;
549                 int lay = il + 1;
550                 rnn_cell_fwd(p, alg, f, sic, slc, dic, wc, batch, n_gates,
551                         &ws(lay, dir_val, iter, H, 0, 0),
552                         &ws(lay, dir_val, iter, C, 0, 0),
553                         &gates(lay - 1, dir_val, iter - 1, 0, 0, 0),
554                         &weights_layer(lay - 1, dir_val, 0, 0),
555                         &weights_iter(lay - 1, dir_val, 0, 0),
556                         &bias(lay - 1, dir_val, 0),
557                         &ws(lay - 1, dir_val, iter, H, 0, 0),
558                         &ws(lay, dir_val, prev_iter, H, 0, 0),
559                         &ws(lay, dir_val, prev_iter, C, 0, 0), ws_local_);
560             }
561         }
562
563         // Finally we copy the results to the result buffers
564         copy_res_fwd(p, alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter,
565                 n_dir, n_states, dst_iter_, dst_layer_, ws_, iter_dir, lay_dir,
566                 dir_val, action, is_concat);
567     };
568
569     switch (direction) {
570     case mkldnn_unidirectional_left2right:
571         process_direction(left2right, bottom2top, 0, action_copy);
572         break;
573     case mkldnn_unidirectional_right2left:
574         process_direction(right2left, bottom2top, 0, action_copy);
575         break;
576     case mkldnn_bidirectional_sum:
577         process_direction(left2right, bottom2top, 0, action_copy);
578         process_direction(right2left, bottom2top, 1, action_sum);
579         break;
580     case mkldnn_bidirectional_concat:
581         process_direction(left2right, bottom2top, 0, action_copy);
582         process_direction(right2left, bottom2top, 1, action_concat);
583         break;
584     default: assert("unknown direction"); break;
585     }
586
587     delete[] ws_local_;
588 }
589
590 void compute_ref_fwd(const rnn_prb_t *p, dnn_mem_t &src_layer_m,
591         dnn_mem_t &src_iter_m, dnn_mem_t &weights_src_layer_m,
592         dnn_mem_t &weights_src_iter_m, dnn_mem_t &bias_m,
593         dnn_mem_t &dst_last_layer_m, dnn_mem_t &dst_last_iteration_m,
594         mkldnn_rnn_direction_t direction) {
595
596     assert(direction == mkldnn_unidirectional_left2right
597             || direction == mkldnn_unidirectional_right2left
598             || direction == mkldnn_bidirectional_sum
599             || direction == mkldnn_bidirectional_concat);
600
601     const int wc = max(p->sic, max(p->slc, p->dic));
602     int ws_size = (p->n_layer + 2) * p->n_directions() * (p->n_iter + 2)
603             * p->n_states() * p->mb * wc;
604     auto *ws = new float[ws_size];
605     int gates_size = p->n_layer * p->n_directions() * p->n_iter * p->mb
606             * p->n_gates() * p->dic;
607     auto *gates = new float[gates_size];
608
609     rnn_linear_fwd(p, direction, (float *)src_iter_m, (float *)src_layer_m,
610             (float *)weights_src_layer_m, (float *)weights_src_iter_m,
611             (float *)bias_m, (float *)dst_last_iteration_m,
612             (float *)dst_last_layer_m, ws, gates);
613
614     delete[] ws;
615     delete[] gates;
616 }
617
618 // =============================================================================
619 // ================ BACKWARD ===================================================
620 // =============================================================================
621 void rnn_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
622         int batch, int n_gates, float *diff_src_layer_, float *diff_src_iter_,
623         float *diff_weights_layer_, float *diff_weights_iter_h_,
624         float *diff_bias_, float *b_gates_, const float *src_layer_,
625         const float *src_iter_, const float *weights_layer_,
626         const float *weights_iter_h_, const float *bias_,
627         const float *dst_iter_h_, const float *gates_,
628         const float *diff_dst_layer_, const float *diff_dst_iter_h_) {
629     AOC<const float> diff_dst_layer(diff_dst_layer_, batch, wc);
630     AOC<const float> diff_dst_iter_h(diff_dst_iter_h_, batch, wc);
631     AOC<const float> gates(gates_, batch, n_gates, dic);
632     AOC<float> b_gates(b_gates_, batch, n_gates, dic);
633
634     for (int b = 0; b < batch; ++b)
635         for (int h = 0; h < dic; ++h) {
636             const float g = gates(b, 0, h);
637             const float dd = diff_dst_layer(b, h) + diff_dst_iter_h(b, h);
638             b_gates(b, 0, h) = activation(f, g, false) * dd;
639         }
640
641     gemm("C", "T", "N", sic, n_gates * dic, batch, 1.0, src_iter_, wc, b_gates_,
642             n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic);
643     gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc, b_gates_,
644             n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic);
645     for (int b = 0; b < batch; ++b)
646         copy(n_gates, dic, dic, dic, &b_gates(b, 0, 0), diff_bias_, action_sum);
647
648     gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic,
649             weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc);
650     gemm("C", "N", "T", batch, sic, n_gates * dic, 1.0, b_gates_, n_gates * dic,
651             weights_iter_h_, n_gates * dic, 0.0, diff_src_iter_, wc);
652 }
653
654 void lstm_bwd(alg_t alg, int sic, int slc, int dic, int wc, int batch,
655         int n_gates, float *diff_src_layer_, float *diff_src_iter_h_,
656         float *diff_src_iter_c_, float *diff_weights_layer_,
657         float *diff_weights_iter_h_, float *diff_bias_, float *b_gates_,
658         const float *src_layer_, const float *src_iter_h_,
659         const float *src_iter_c_, const float *weights_layer_,
660         const float *weights_iter_h_, const float *bias_,
661         const float *dst_iter_h_, const float *dst_iter_c_, const float *gates_,
662         const float *diff_dst_layer_, const float *diff_dst_iter_h_,
663         const float *diff_dst_iter_c_) {
664     // TODO: check sic and slc as last dimension in arrays and cycles
665     // input
666     AOC<const float> diff_dst_layer(diff_dst_layer_, batch, wc);
667     AOC<const float> diff_dst_iter_c(diff_dst_iter_c_, batch, wc);
668     AOC<const float> diff_dst_iter_h(diff_dst_iter_h_, batch, wc);
669     AOC<const float> src_iter_c(src_iter_c_, batch, wc);
670     AOC<const float> dst_iter_h(dst_iter_h_, batch, wc);
671     AOC<const float> dst_iter_c(dst_iter_c_, batch, wc);
672     AOC<const float> gates(gates_, batch, n_gates, dic);
673
674     AOC<float> diff_src_iter_c(diff_src_iter_c_, batch, wc);
675     AOC<float> b_gates(b_gates_, batch, n_gates, dic);
676
677     const int ohi = 0;
678     const int ohf = 1;
679     const int ohc = 2;
680     const int oho = 3;
681
682     for (int ib = 0; ib < batch; ib++)
683         for (int ih = 0; ih < dic; ih++) {
684             print(80, "rnn_single_bwd: ib = %d ih = %d\n", ib, ih);
685             float ho = gates(ib, oho, ih);
686             float hf = gates(ib, ohf, ih);
687             float hc = gates(ib, ohc, ih);
688             float hi = gates(ib, ohi, ih);
689             float dh = diff_dst_layer(ib, ih) + diff_dst_iter_h(ib, ih);
690             float c = dst_iter_c(ib, ih);
691             float dho = tanhf(c) * dh;
692             b_gates(ib, oho, ih) = x_m_square(ho) * dho;
693
694             float dc_next = diff_dst_iter_c(ib, ih);
695             float dc = ho * dh * dtanhf(c) + dc_next;
696             diff_src_iter_c(ib, ih) = hf * dc;
697
698             float c_old = src_iter_c(ib, ih);
699             float dhf = c_old * dc;
700             b_gates(ib, ohf, ih) = x_m_square(hf) * dhf;
701
702             float dhi = hc * dc;
703             b_gates(ib, ohi, ih) = x_m_square(hi) * dhi;
704
705             float dhc = hi * dc;
706             b_gates(ib, ohc, ih) = one_m_square(hc) * dhc;
707         }
708
709     gemm("C", "T", "N", sic, n_gates * dic, batch, 1.0, src_iter_h_, wc, b_gates_,
710             n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic);
711     gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc, b_gates_,
712             n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic);
713
714     gemm("C", "N", "T", batch, sic, n_gates * dic, 1.0, b_gates_, n_gates * dic,
715             weights_iter_h_, n_gates * dic, 0.0, diff_src_iter_h_, wc);
716     gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic,
717             weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc);
718
719     for (int i = 0; i < batch; i++)
720         for (int j = 0; j < n_gates; j++)
721             for (int k = 0; k < dic; k++)
722                 diff_bias_[j * dic + k] += b_gates(i, j, k);
723 }
724
725 void gru_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
726         int batch, int n_gates, float *diff_src_layer_, float *diff_src_iter_,
727         float *diff_weights_layer_, float *diff_weights_iter_h_,
728         float *diff_bias_, float *b_gates_, const float *src_layer_,
729         const float *src_iter_, const float *weights_layer_,
730         const float *weights_iter_h_, const float *bias_,
731         const float *dst_iter_h_, const float *gates_,
732         const float *diff_dst_layer_, const float *diff_dst_iter_h_,
733         float *ws_local_) {
734
735     AOC<const float> src_iter(src_iter_, batch, wc);
736     AOC<const float> diff_dst_layer(diff_dst_layer_, batch, wc);
737     AOC<const float> diff_dst_iter_h(diff_dst_iter_h_, batch, wc);
738     AOC<const float> gates(gates_, batch, n_gates, dic);
739     AOC<const float> weights_layer(weights_layer_, slc, n_gates, dic);
740     AOC<const float> weights_iter_h(weights_iter_h_, sic, n_gates, dic);
741
742     AOC<float> diff_src_iter(diff_src_iter_, batch, wc);
743     AOC<float> diff_weights_iter_h(diff_weights_iter_h_, sic, n_gates, dic);
744     AOC<float> b_gates(b_gates_, batch, n_gates, dic);
745
746     float *dhr_ = ws_local_;
747     float *hr_ = ws_local_ + batch * wc;
748     AOC<float> dhr(dhr_, batch, wc);
749     AOC<float> hr(hr_, batch, wc);
750
751 // dc = (1 - u) * dh; dc^ = one_m_square(c) * dc;
752 // du = (h - u) * dh; du^ = x_m_square(u) * du;
753 // dhr = Wc dc^;
754 // dr = h * dhr; dr^ = x_m_square(r) * dr;
755     const int ohu = 0;
756     const int ohr = 1;
757     const int ohc = 2;
758     for (int ib = 0; ib < batch; ib++)
759         for (int ih = 0; ih < dic; ih++) {
760             float h = src_iter(ib, ih);
761             float c = gates(ib, ohc, ih);
762             float u = gates(ib, ohu, ih);
763             float dh = diff_dst_layer(ib, ih) + diff_dst_iter_h(ib, ih);
764             float du = (h - c) * dh;
765             float dc = (1.0f - u) * dh;
766             b_gates(ib, ohu, ih) = x_m_square(u) * du;
767             b_gates(ib, ohc, ih) = one_m_square(c) * dc;
768             diff_src_iter(ib, ih) = dh * u;
769         }
770     gemm("C", "N", "T", batch, sic, dic, 1.0, &(b_gates(0, 2, 0)), n_gates * dic,
771             &(weights_iter_h(0, 2, 0)), n_gates * dic, 0.0, dhr_, wc);
772
773     for (int ib = 0; ib < batch; ib++)
774         for (int ih = 0; ih < dic; ih++) {
775             float h = src_iter(ib, ih);
776             float r = gates(ib, ohr, ih);
777             float dr = h * dhr(ib, ih);
778             hr(ib, ih) = h * r;
779             diff_src_iter(ib, ih) += dhr(ib, ih) * r;
780             b_gates(ib, ohr, ih) = x_m_square(r) * dr;
781         }
782
783 // dWx += xdu^ | xdr^ | xdc^
784 // dWh += hdu^ | ddr^ | (h * r)dc^
785     gemm("C", "T", "N", sic, (n_gates - 1) * dic, batch, 1.0, src_iter_, wc,
786             b_gates_, n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic);
787     gemm("C", "T", "N", sic, dic, batch, 1.0, hr_, wc, &(b_gates(0, 2, 0)),
788             n_gates * dic, 1.0, &(diff_weights_iter_h(0, 2, 0)), n_gates * dic);
789     gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc,
790             b_gates_, n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic);
791
792 // dx_next = Wxudu^ + Wxrdr^ + Wxcdc^
793 // dh_next = dh * u + Whudu^ + Whzdz^ + r * Whcdc^
794     gemm("C", "N", "T", batch, sic, (n_gates - 1)* dic, 1.0, b_gates_,
795             n_gates * dic, weights_iter_h_, n_gates * dic, 1.0, diff_src_iter_,
796             wc);
797     gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic,
798             weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc);
799
800     for (int i = 0; i < batch; i++)
801         for (int j = 0; j < n_gates; j++)
802             for (int k = 0; k < dic; k++)
803                 diff_bias_[j * dic + k] += b_gates(i, j, k);
804 }
805
806 void gru_lbr_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
807         int batch, int n_gates, float *diff_src_layer_, float *diff_src_iter_,
808         float *diff_weights_layer_, float *diff_weights_iter_h_,
809         float *diff_bias_, float *b_gates_, const float *src_layer_,
810         const float *src_iter_, const float *weights_layer_,
811         const float *weights_iter_h_, const float *bias_,
812         const float *dst_iter_h_, const float *gates_,
813         const float *diff_dst_layer_, const float *diff_dst_iter_h_,
814         float *ws_local_) {
815
816     AOC<const float> src_iter(src_iter_, batch, wc);
817     AOC<const float> diff_dst_layer(diff_dst_layer_, batch, wc);
818     AOC<const float> diff_dst_iter_h(diff_dst_iter_h_, batch, wc);
819     AOC<const float> gates(gates_, batch, n_gates, dic);
820     AOC<const float> weights_layer(weights_layer_, slc, n_gates, dic);
821     AOC<const float> weights_iter_h(weights_iter_h_, sic, n_gates, dic);
822     AOC<const float> bias(bias_, n_gates + 1, dic);
823
824     AOC<float> diff_src_iter(diff_src_iter_, batch, wc);
825     AOC<float> diff_weights_iter_h(diff_weights_iter_h_, dic, n_gates, sic);
826     AOC<float> b_gates(b_gates_, batch, n_gates, dic);
827
828     float *Wh_b_ = ws_local_;
829     float *b_gates_r_ = ws_local_ + dic * batch;
830     AOC<float> Wh_b(Wh_b_, batch, dic);
831     AOC<float> b_gates_r(b_gates_r_, batch, n_gates, dic);
832
833     for (int ib = 0; ib < batch; ib++)
834         for (int ih = 0; ih < dic; ih++)
835             Wh_b(ib, ih) = bias(3, ih);
836
837     gemm("C", "N", "N", batch, dic, sic, 1.0, src_iter_, wc,
838             &weights_iter_h(0, 2, 0), n_gates * dic, 1.0, Wh_b_, dic);
839
840
841 // dc = (1 - u) * dh; dc^ = one_m_square(c) * dc;
842 // du = (h - c) * dh; du^ = x_m_square(u) * du;
843 // dr = (Wh + b) * dc^; dr^ = x_m_square(r) * dr;
844     const int ohu = 0;
845     const int ohr = 1;
846     const int ohc = 2;
847     for (int ib = 0; ib < batch; ib++)
848         for (int ih = 0; ih < dic; ih++) {
849             float h = src_iter(ib, ih);
850             float dh = diff_dst_layer(ib, ih) + diff_dst_iter_h(ib, ih);
851             float u = gates(ib, ohu, ih);
852             float r = gates(ib, ohr, ih);
853             float c = gates(ib, ohc, ih);
854             float du = (h - c) * dh;
855             float dc = (1.0f - u) * dh;
856
857             b_gates(ib, ohu, ih) = x_m_square(u) * du;
858             b_gates(ib, ohc, ih) = one_m_square(c) * dc;
859
860             float dr = Wh_b(ib, ih) * b_gates(ib, ohc, ih);
861             b_gates(ib, ohr, ih) = x_m_square(r) * dr;
862
863             b_gates_r(ib, ohu, ih) = b_gates(ib, ohu, ih);
864             b_gates_r(ib, ohr, ih) = b_gates(ib, ohr, ih);
865             b_gates_r(ib, ohc, ih) = b_gates(ib, ohc, ih) * r;
866             diff_src_iter(ib, ih) = dh * u;
867         }
868
869     gemm("C", "T", "N", sic, n_gates * dic, batch, 1.0, src_iter_, wc, b_gates_r_,
870             n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic);
871     gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc, b_gates_,
872             n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic);
873
874     gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic,
875             weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc);
876     gemm("C", "N", "T", batch, sic, n_gates * dic, 1.0, b_gates_r_, n_gates * dic,
877             weights_iter_h_, n_gates * dic, 1.0, diff_src_iter_, wc);
878
879     for (int i = 0; i < batch; i++)
880         for (int j = 0; j < n_gates; j++)
881             for (int k = 0; k < dic; k++)
882                 diff_bias_[j * dic + k] += b_gates(i, j, k);
883
884     for (int i = 0; i < batch; i++)
885         for (int k = 0; k < dic; k++)
886             diff_bias_[3 * dic + k] += b_gates_r(i, 2, k);
887 }
888
889
890 void rnn_cell_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
891         int batch, int n_gates, float *diff_src_layer, float *diff_src_iter_h,
892         float *diff_src_iter_c, float *diff_weights_layer,
893         float *diff_weights_iter, float *diff_bias, float *b_gates,
894         const float *src_layer, const float *src_iter_h,
895         const float *src_iter_c, const float *weights_layer,
896         const float *weights_iter, const float *bias, const float *dst_iter_h,
897         const float *dst_iter_c, const float *gates,
898         const float *diff_dst_layer, const float *diff_dst_iter_h,
899         const float *diff_dst_iter_c, float *ws_local_) {
900
901     switch (alg) {
902     case VANILLA_LSTM:
903         lstm_bwd(alg, sic, slc, dic, wc, batch, n_gates, diff_src_layer,
904                 diff_src_iter_h, diff_src_iter_c, diff_weights_layer,
905                 diff_weights_iter, diff_bias, b_gates, src_layer, src_iter_h,
906                 src_iter_c, weights_layer, weights_iter, bias, dst_iter_h,
907                 dst_iter_c, gates, diff_dst_layer, diff_dst_iter_h,
908                 diff_dst_iter_c);
909         break;
910     case VANILLA_RNN:
911         rnn_bwd(alg, f, sic, slc, dic, wc, batch, n_gates, diff_src_layer,
912                 diff_src_iter_h, diff_weights_layer, diff_weights_iter,
913                 diff_bias, b_gates, src_layer, src_iter_h, weights_layer,
914                 weights_iter, bias, dst_iter_h, gates, diff_dst_layer,
915                 diff_dst_iter_h);
916         break;
917     case VANILLA_GRU:
918         gru_bwd(alg, f, sic, slc, dic, wc, batch, n_gates, diff_src_layer,
919                 diff_src_iter_h, diff_weights_layer, diff_weights_iter,
920                 diff_bias, b_gates, src_layer, src_iter_h, weights_layer,
921                 weights_iter, bias, dst_iter_h, gates, diff_dst_layer,
922                 diff_dst_iter_h, ws_local_);
923         break;
924     case LBR_GRU:
925         gru_lbr_bwd(alg, f, sic, slc, dic, wc, batch, n_gates, diff_src_layer,
926                 diff_src_iter_h, diff_weights_layer, diff_weights_iter,
927                 diff_bias, b_gates, src_layer, src_iter_h, weights_layer,
928                 weights_iter, bias, dst_iter_h, gates, diff_dst_layer,
929                 diff_dst_iter_h, ws_local_);
930     default: break;
931     }
932 }
933
934 void rnn_linear_bwd(const rnn_prb_t *p, mkldnn_rnn_direction_t direction,
935         const float *diff_dst_iter_, const float *diff_dst_layer_,
936         const float *weights_layer_, const float *weights_iter_h_,
937         const float *bias_, float *diff_src_iter_, float *diff_src_layer_,
938         float *diff_weights_layer_, float *diff_weights_iter_h_,
939         float *diff_bias_, float *ws_, const float *gates_) {
940
941     const alg_t alg = p->alg;
942     const int sic = p->sic;
943     const int slc = p->slc;
944     const int dic = p->dic;
945     const int dlc = p->dlc;
946     const int wc = max(sic, max(slc, dic));
947     bool is_lbr = p->alg == LBR_GRU;
948
949     const int batch = p->mb;
950     const int n_gates = p->n_gates();
951     const int n_states = p->n_states();
952     const int n_layer = p->n_layer;
953     const int n_iter = p->n_iter;
954     const int n_dir = p->n_directions();
955     activation_t f = p->activation;
956
957     const int X = n_states;
958
959     AOC<const float> bias(bias_, n_layer, n_dir, n_gates + is_lbr, dic);
960     AOC<float> diff_bias(diff_bias_, n_layer, n_dir, n_gates + is_lbr, dic);
961
962     AOC<const float> weights_layer(
963             weights_layer_, n_layer, n_dir, n_gates * dic, slc);
964     AOC<const float> weights_iter(
965             weights_iter_h_, n_layer, n_dir, n_gates * dic, sic);
966
967     AOC<float> diff_weights_layer(
968             diff_weights_layer_, n_layer, n_dir, n_gates * dic, slc);
969     AOC<float> diff_weights_iter(
970             diff_weights_iter_h_, n_layer, n_dir, n_gates * dic, sic);
971
972     auto *b_gates = new float[batch * n_gates * dic];
973     AOC<float> ws(ws_, n_layer + 2, n_dir, n_iter + 2, n_states, batch, wc);
974     AOC<const float> gates(gates_, n_layer, n_dir, n_iter, batch, n_gates, dic);
975
976     int wsb_size = (n_layer + 2) * n_dir * (n_iter + 2) * (n_states + 1) * batch
977             * wc;
978     auto *wsb_ = new float[wsb_size];
979     init_buffer(wsb_, wsb_size, 0.); // ??!! Temporary. For debug.
980     // n_states + 1  -- H, C, X
981     AOC<float> wsb(
982             wsb_, n_layer + 2, n_dir, n_iter + 2, n_states + 1, batch, wc);
983
984     int ws_local_size;
985     switch (p->alg) {
986         case LBR_GRU:
987             ws_local_size = batch * (n_gates + 1) * dic;
988             break;
989         case VANILLA_GRU:
990             ws_local_size = 2 * batch * wc;
991             break;
992         default: ws_local_size = 0;
993     }
994     float *ws_local_ = new float[ws_local_size];
995
996     auto process_direction = [&](rnn_iter_direction_t iter_dir,
997             rnn_layer_direction_t lay_dir, int dir_val, rnn_action_t action) {
998         // we first need to copy the initial states and input into ws
999         // it simplifies the logic in the following code
1000         copy_init_bwd(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter,
1001                 n_dir, n_states, wsb_, diff_dst_layer_, diff_dst_iter_,
1002                 iter_dir, lay_dir, dir_val,
1003                 direction == mkldnn_bidirectional_concat);
1004
1005         // We run the grid of computation
1006         for (int j = n_layer - 1; j >= 0; j--) {
1007             for (int i = 0; i < n_iter; i++) {
1008                 int iter = (iter_dir == left2right) ? i + 1 : n_iter - i;
1009                 int prev_iter = (iter_dir == left2right) ? iter - 1 : iter + 1;
1010                 int lay = j + 1;
1011                 int prev_lay = lay + 1;
1012
1013                 int ws_iter = (iter_dir == left2right) ? iter : iter;
1014                 int ws_prev_iter
1015                         = (iter_dir == left2right) ? iter + 1 : iter - 1;
1016
1017                 rnn_cell_bwd(alg, f, sic, slc, dic, wc, batch, n_gates,
1018                         &wsb(lay, dir_val, iter, X, 0, 0),
1019                         &wsb(lay, dir_val, iter, H, 0, 0),
1020                         &wsb(lay, dir_val, iter, C, 0, 0),
1021                         &diff_weights_layer(lay - 1, dir_val, 0, 0),
1022                         &diff_weights_iter(lay - 1, dir_val, 0, 0),
1023                         &diff_bias(lay - 1, dir_val, 0, 0), b_gates,
1024                         &ws(lay - 1, dir_val, ws_iter, H, 0, 0),
1025                         &ws(lay, dir_val, ws_prev_iter, H, 0, 0),
1026                         &ws(lay, dir_val, ws_prev_iter, C, 0, 0),
1027                         &weights_layer(lay - 1, dir_val, 0, 0),
1028                         &weights_iter(lay - 1, dir_val, 0, 0),
1029                         &bias(lay - 1, dir_val, 0, 0),
1030                         &ws(lay, dir_val, ws_iter, H, 0, 0),
1031                         &ws(lay, dir_val, ws_iter, C, 0, 0),
1032                         &gates(lay - 1, dir_val, ws_iter - 1, 0, 0, 0),
1033                         &wsb(prev_lay, dir_val, iter, X, 0, 0),
1034                         &wsb(lay, dir_val, prev_iter, H, 0, 0),
1035                         &wsb(lay, dir_val, prev_iter, C, 0, 0),
1036                         ws_local_);
1037             }
1038         }
1039
1040         // Finally we copy the results to the result buffers
1041         copy_res_bwd(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter, n_dir,
1042                 n_states, diff_src_iter_, diff_src_layer_, wsb_, iter_dir,
1043                 lay_dir, dir_val, action);
1044     };
1045
1046     switch (direction) {
1047     case mkldnn_unidirectional_left2right:
1048         process_direction(right2left, top2bottom, 0, action_copy);
1049         break;
1050     case mkldnn_unidirectional_right2left:
1051         process_direction(left2right, top2bottom, 0, action_copy);
1052         break;
1053     case mkldnn_bidirectional_sum:
1054         process_direction(right2left, top2bottom, 0, action_copy);
1055         process_direction(left2right, top2bottom, 1, action_sum);
1056         break;
1057     case mkldnn_bidirectional_concat:
1058         process_direction(right2left, top2bottom, 0, action_copy);
1059         process_direction(left2right, top2bottom, 1, action_sum);
1060         break;
1061     default: assert("unknown direction"); break;
1062     }
1063
1064     delete[] wsb_;
1065     delete[] b_gates;
1066     delete[] ws_local_;
1067 }
1068
1069 void compute_ref_bwd(const rnn_prb_t *p, dnn_mem_t &input_m,
1070         dnn_mem_t &states_m, dnn_mem_t &diff_last_layer_m,
1071         dnn_mem_t &diff_last_iteration_m, dnn_mem_t &weights_input_m,
1072         dnn_mem_t &weights_states_m, dnn_mem_t &bias_m,
1073         dnn_mem_t &dst_last_layer_m, dnn_mem_t &dst_last_iteration_m,
1074         dnn_mem_t &dst_diff_input_m, dnn_mem_t &dst_diff_states_m,
1075         dnn_mem_t &dst_diff_weights_input_m,
1076         dnn_mem_t &dst_diff_weights_states_m, dnn_mem_t &dst_diff_bias_m,
1077         mkldnn_rnn_direction_t direction) {
1078     // !! TODO: add support of strides
1079
1080     assert(direction == mkldnn_unidirectional_left2right
1081             || direction == mkldnn_unidirectional_right2left
1082             || direction == mkldnn_bidirectional_sum
1083             || direction == mkldnn_bidirectional_concat);
1084
1085     assert(p->dlc == p->dic);
1086     int wc = max(p->sic, max(p->slc, p->dic));
1087     int ws_size = (p->n_layer + 2) * p->n_directions() * (p->n_iter + 2)
1088             * p->n_states() * p->mb * wc;
1089     auto *ws = new float[ws_size];
1090     init_buffer(ws, ws_size, -55.); // ??!! Temporary. For debug.
1091     int gates_size = p->n_layer * p->n_directions() * p->n_iter * p->mb
1092             * p->n_gates() * p->dic;
1093     auto *gates = new float[gates_size];
1094
1095     rnn_linear_fwd(p, direction, (float *)states_m, (float *)input_m,
1096             (float *)weights_input_m, (float *)weights_states_m,
1097             (float *)bias_m, (float *)dst_last_iteration_m,
1098             (float *)dst_last_layer_m, ws, gates);
1099
1100     rnn_linear_bwd(p, direction, (float *)diff_last_iteration_m,
1101             (float *)diff_last_layer_m, (float *)weights_input_m,
1102             (float *)weights_states_m, (float *)bias_m,
1103             (float *)dst_diff_states_m, (float *)dst_diff_input_m,
1104             (float *)dst_diff_weights_input_m,
1105             (float *)dst_diff_weights_states_m, (float *)dst_diff_bias_m, ws,
1106             gates);
1107
1108     delete[] ws;
1109 }
1110
1111 } // namespace rnn