1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "src/common/mkldnn_thread.hpp"
21 #include "rnn/rnn.hpp"
22 #include "rnn/rnn_aux.hpp"
26 #define min(a, b) ((a < b) ? a : b)
27 #define max(a, b) ((a > b) ? a : b)
28 #define xstr(a) str(a)
31 #define AOC array_offset_calculator
// Applies the LSTM gate non-linearities in place on the packed gate buffer:
// gates 0, 1 and 3 get the logistic sigmoid, gate 2 gets tanh (gate order
// follows this file's oh* convention -- TODO confirm against ohi/ohf/ohc/oho).
// Parallelized over the batch dimension via mkldnn's parallel_nd.
void lstm_activation(int dic, int n_gates, int batch,
        // float a[batch][n_gates * wc]
    // View the flat buffer as [batch][n_gates][dic].
    AOC<float> pa(a, batch, n_gates, dic);
    mkldnn::impl::parallel_nd(batch, [&](int ib) {
        for (int ih = 0; ih < dic; ih++) {
            pa(ib, 0, ih) = logistic(pa(ib, 0, ih));
            pa(ib, 1, ih) = logistic(pa(ib, 1, ih));
            pa(ib, 2, ih) = tanhf(pa(ib, 2, ih));
            pa(ib, 3, ih) = logistic(pa(ib, 3, ih));
            // Verbose-level trace of all four activated gate values.
            for (int ig = 0; ig < 4; ig++) {
                print(80, "activation 1 a[%d][%d][%d] = %.7f\n", ib, ig, ih,
// Scalar activation dispatcher. For is_fwd the forward value is returned;
// otherwise the derivative, expressed in terms of the *forward output* x
// (e.g. logistic' = x*(1-x) via x_m_square, tanh' = 1-x^2 via one_m_square).
float activation(activation_t f, float x, bool is_fwd = true) {
    case RELU: result = is_fwd ? relu(x) : drelu(x); break;
    case LOGISTIC: result = is_fwd ? logistic(x) : x_m_square(x); break;
    case TANH: result = is_fwd ? tanhf(x) : one_m_square(x); break;
    default: assert(!"unknown activation");
// Vanilla RNN cell, forward pass:
//   gates = src_layer * W_layer + src_iter_h * W_iter
//   dst_iter_h = f(gates + bias)
// wc is the common leading dimension (max of sic/slc/dic) of the workspace.
void rnn_fwd(activation_t f, int sic, int slc, int dic, int wc, int batch,
        int n_gates, float *dst_iter_h_, float *gates_,
        const float *weights_layer_, const float *weights_iter_h_,
        const float *bias_, const float *src_layer_, const float *src_iter_h_) {
    AOC<float> dst_iter_h(dst_iter_h_, batch, n_gates, wc);
    AOC<const float> bias(bias_, n_gates, dic);
    AOC<float> gates(gates_, batch, n_gates, dic);
    // Layer input contribution (beta = 0 initializes gates)...
    gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc,
            weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic);
    // ...then accumulate the recurrent contribution (beta = 1).
    gemm("C", "N", "N", batch, n_gates * dic, sic, 1.0, src_iter_h_, wc,
            weights_iter_h_, n_gates * dic, 1.0, gates_, n_gates * dic);
    // Bias + elementwise activation.
    for (int i = 0; i < batch; i++)
        for (int j = 0; j < n_gates; j++)
            for (int k = 0; k < dic; k++) {
                const auto tmp = activation(f, gates(i, j, k) + bias(j, k));
                dst_iter_h(i, j, k) = tmp;
// GRU cell, forward pass. Gate layout in `gates`: 0 = update (u),
// 1 = reset (r), 2 = candidate (c) -- grounded by the final blend below:
//   h_dst = u * h_prev + (1 - u) * c
void gru_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates,
        float *dst_iter_h_, float *gates_, const float *weights_layer_,
        const float *weights_iter_h_, const float *bias_,
        const float *src_layer_, const float *src_iter_h_) {
    AOC<const float> src_iter_h(src_iter_h_, batch, wc);
    AOC<const float> weights_layer(weights_layer_, slc, n_gates, dic);
    AOC<const float> weights_iter_h(weights_iter_h_, sic, n_gates, dic);
    AOC<const float> bias(bias_, n_gates, dic);
    AOC<float> gates(gates_, batch, n_gates, dic);
    AOC<float> h_dst(dst_iter_h_, batch, wc);
    // Layer input contribution for all gates.
    gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc,
            weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic);
    // Recurrent contribution for u and r only (first n_gates - 1 gates);
    // the candidate's recurrent term needs r applied first, done below.
    gemm("C", "N", "N", batch, (n_gates - 1) * dic, sic, 1.0, src_iter_h_,
            wc, weights_iter_h_, n_gates * dic, 1.0, gates_, n_gates * dic);
    // Sigmoid on u and r.
    for (int i = 0; i < batch; i++)
        for (int j = 0; j < n_gates - 1; j++)
            for (int k = 0; k < dic; k++) {
                gates(i, j, k) = logistic(gates(i, j, k) + bias(j, k));
    // Stash h_prev * r in h_dst as a scratch operand for the candidate gemm.
    for (int i = 0; i < batch; i++)
        for (int k = 0; k < dic; k++) {
            h_dst(i, k) = src_iter_h(i, k) * gates(i, 1, k);
    // Candidate recurrent term: (h_prev * r) * W_iter[c], accumulated.
    gemm("C", "N", "N", batch, dic, sic, 1.0, dst_iter_h_, wc,
            &(weights_iter_h(0, 2, 0)), n_gates * dic, 1.0, &(gates(0, 2, 0)),
    // tanh on the candidate.
    for (int i = 0; i < batch; i++)
        for (int k = 0; k < dic; k++) {
            gates(i, 2, k) = tanhf(gates(i, 2, k) + bias(2, k));
    // Final blend: h_dst = u * h_prev + (1 - u) * c.
    for (int i = 0; i < batch; i++)
        for (int k = 0; k < dic; k++) {
            h_dst(i, k) = gates(i, 0, k) * src_iter_h(i, k) +
                (1 - gates(i, 0, k)) * gates(i, 2, k);
// Linear-before-reset GRU cell, forward pass. Unlike plain GRU, the
// recurrent product is computed into a scratch buffer (ws_local_) first and
// the reset gate is applied to the candidate's recurrent term afterwards.
// Bias has n_gates + 1 rows: the extra row (bias(3, .)) is the recurrent
// candidate bias specific to LBR.
void gru_lbr_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates,
        float *dst_iter_h_, float *gates_, const float *weights_layer_,
        const float *weights_iter_h_, const float *bias_,
        const float *src_layer_, const float *src_iter_h_,
    AOC<const float> src_iter_h(src_iter_h_, batch, wc);
    AOC<const float> weights_layer(weights_layer_, slc, n_gates, dic);
    AOC<const float> weights_iter_h(weights_iter_h_, sic, n_gates, dic);
    AOC<const float> bias(bias_, n_gates + 1, dic);
    AOC<float> gates(gates_, batch, n_gates, dic);
    AOC<float> h_dst(dst_iter_h_, batch, wc);
    AOC<float> tmp_ws(ws_local_, batch, n_gates, dic);
    // Layer input contribution into gates...
    gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc,
            weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic);
    // ...and the full recurrent product into the scratch buffer.
    gemm("C", "N", "N", batch, n_gates * dic, sic, 1.0, src_iter_h_, wc,
            weights_iter_h_, n_gates * dic, 0.0, ws_local_, n_gates * dic);
    // Sigmoid on u and r (layer + recurrent terms).
    for (int i = 0; i < batch; i++)
        for (int j = 0; j < n_gates - 1; j++)
            for (int k = 0; k < dic; k++) {
                gates(i, j, k) = logistic(gates(i, j, k) + tmp_ws(i, j, k)
    // Candidate: tanh(Wx_c + r * (Wh_c + b_recur) + b_c) -- reset applied
    // after the recurrent matmul, which is the LBR distinction.
    for (int i = 0; i < batch; i++)
        for (int k = 0; k < dic; k++) {
            gates(i, 2, k) = tanhf(gates(i, 2, k) + gates(i, 1, k) * (tmp_ws(i, 2, k)
                + bias(3, k)) + bias(2, k));
    // Final blend: h_dst = u * h_prev + (1 - u) * c.
    for (int i = 0; i < batch; i++)
        for (int k = 0; k < dic; k++) {
            h_dst(i, k) = gates(i, 0, k) * src_iter_h(i, k) +
                (1 - gates(i, 0, k)) * gates(i, 2, k);
// w = [weights_layer | weights_iter] : with order f, i , o, \bar(c)
// LSTM cell, forward pass, with optional int8 (de)quantization controlled by
// the problem config: weights are dequantized after the gemms (maybe_deq_w)
// and the hidden output is re-quantized/clamped (maybe_q_d).
void lstm_fwd(const rnn_prb_t *p, int sic, int slc, int dic, int wc, int batch,
        int n_gates, float *dst_iter_h_, float *c_dst_, float *gates_,
        const float *weights_layer_, const float *weights_iter_h_,
        const float *bias_, const float *src_layer_, const float *src_iter_h_,
        const float *src_iter_c_) {
    AOC<float> h_dst(dst_iter_h_, batch, wc);
    AOC<float> c_dst(c_dst_, batch, wc);
    AOC<const float> bias(bias_, n_gates, dic);
    AOC<const float> src_iter_c(src_iter_c_, batch, wc);
    AOC<float> gates(gates_, batch, n_gates, dic);
    // Layer input then recurrent contribution into the gate buffer.
    gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc,
            weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic);
    gemm("C", "N", "N", batch, n_gates * dic, sic, 1.0, src_iter_h_, wc,
            weights_iter_h_, n_gates * dic, 1.0, gates_, n_gates * dic);
    // Dequantize a gate value for output channel oc; no-op for f32 configs.
    // Scale selection depends on the per-OC vs common weight-scale policy.
    auto maybe_deq_w = [&](float g, int oc) {
        if (p->cfg == conf_f32)
        if (p->scale_policy == PER_OC)
            scale = p->wei_oc_scales[oc];
        else if (p->scale_policy == COMMON)
            scale = p->wei_scale;
        scale *= p->data_scale;
    // Dequantize + bias for every gate element.
    for (int i = 0; i < batch; i++)
        for (int j = 0; j < n_gates; j++)
            for (int k = 0; k < dic; k++) {
                = maybe_deq_w(gates(i, j, k), j * dic + k) + bias(j, k);
    // Apply sigmoid/tanh per gate (see lstm_activation).
    lstm_activation(dic, n_gates, batch, gates_);
    // Quantize an output value: scale, round per the attr round mode, then
    // clamp so that value + data_shift stays within the input dt range.
    auto maybe_q_d = [&](float h) {
        if (p->cfg == conf_f32)
        float fp = p->data_scale * h;
        using R = attr_t::round_mode_t;
        switch (p->attr.irmode) {
        case R::DOWN: fp = floorf(fp); break;
        case R::NEAREST: fp = nearbyintf(fp); break;
        default: assert(!"unkown round mode");
        if (fp + p->data_shift > p->cfg[input].max)
            fp = p->cfg[input].max - p->data_shift;
        if (fp + p->data_shift < p->cfg[input].min)
            fp = p->cfg[input].min - p->data_shift;
    // compute C_t_l and H_t_l
    // c_t = f * c_{t-1} + i * \bar(c);  h_t = o * tanh(c_t)
    for (int i = 0; i < batch; i++)
        for (int j = 0; j < dic; j++) {
            float tmp = gates(i, ohf, j) * src_iter_c(i, j)
                + gates(i, ohi, j) * gates(i, ohc, j);
            h_dst(i, j) = maybe_q_d(gates(i, oho, j) * tanhf(tmp));
// Forward dispatch for one cell: routes to the GRU / LBR-GRU / LSTM /
// vanilla-RNN implementation according to alg. ws_local_ is scratch used by
// the LBR variant; dst_iter_c/src_iter_c are only meaningful for LSTM.
void rnn_cell_fwd(const rnn_prb_t *p, alg_t alg, activation_t f, int sic,
        int slc, int dic, int wc, int batch, int n_gates, float *dst_iter_h,
        float *dst_iter_c, float *gates, const float *weights_layer,
        const float *weights_iter, const float *bias, const float *src_layer,
        const float *src_iter_h, const float *src_iter_c, float *ws_local_) {
        gru_fwd(sic, slc, dic, wc, batch, n_gates, dst_iter_h, gates,
                weights_layer, weights_iter, bias, src_layer, src_iter_h);
        gru_lbr_fwd(sic, slc, dic, wc, batch, n_gates, dst_iter_h, gates,
                weights_layer, weights_iter, bias, src_layer, src_iter_h,
        lstm_fwd(p, sic, slc, dic, wc, batch, n_gates, dst_iter_h, dst_iter_c,
                gates, weights_layer, weights_iter, bias, src_layer, src_iter_h,
        rnn_fwd(f, sic, slc, dic, wc, batch, n_gates, dst_iter_h, gates,
                weights_layer, weights_iter, bias, src_layer, src_iter_h);
// Copies (or accumulates, when action == action_sum) a dimc x dimr 2D block
// between buffers with distinct leading dimensions; parallel over rows.
void copy(int dimc, int dimr, int ld_src, int ld_dst, const float *src_,
        float *dst_, rnn_action_t action = action_copy) {
    AOC<const float> src(src_, dimc, ld_src);
    AOC<float> dst(dst_, dimc, ld_dst);
    mkldnn::impl::parallel_nd(dimc, [&](int i) {
        for (int j = 0; j < dimr; j++) {
            dst(i, j) = action == action_sum
                ? dst(i, j) + src(i, j) : src(i, j);
// Adds `shift` to every element of a dimc x dimr block in place; when
// round is set, additionally rounds per p->attr.irmode (used for int8
// shifting between u8 and s8 domains). Parallel over rows.
void shift(int dimc, int dimr, int ld_src, float *src_, float shift,
        bool round = false, const rnn_prb_t *p = nullptr) {
    AOC<float> src(src_, dimc, ld_src);
    mkldnn::impl::parallel_nd(dimc, [&](int i) {
        for (int j = 0; j < dimr; j++) {
            float fp = src(i, j) + shift;
            using R = attr_t::round_mode_t;
            switch (p->attr.irmode) {
            case R::DOWN: fp = floorf(fp); break;
            case R::NEAREST: fp = nearbyintf(fp); break;
            default: assert(!"unkown round mode");
// Multiplies every element of a dimc x dimr block by `scale` in place; when
// round is set, additionally rounds per p->attr.irmode (quantization helper,
// mirrors shift() above). Parallel over rows.
void scale(int dimc, int dimr, int ld_src, float *src_, float scale,
        bool round = false, const rnn_prb_t *p = nullptr) {
    AOC<float> src(src_, dimc, ld_src);
    mkldnn::impl::parallel_nd(dimc, [&](int i) {
        for (int j = 0; j < dimr; j++) {
            float fp = src(i, j) * scale;
            using R = attr_t::round_mode_t;
            switch (p->attr.irmode) {
            case R::DOWN: fp = floorf(fp); break;
            case R::NEAREST: fp = nearbyintf(fp); break;
            default: assert(!"unkown round mode");
 * fwd: ws keeps {h, c} for every cell
// Seeds the forward workspace: copies the layer input into the border layer
// row and the initial iteration states into the border iteration column.
// For int8 configs, inputs are shifted from u8 to s8 (to avoid gemm
// compensation) and f32 states are pre-scaled into the quantized domain.
void copy_init_fwd(const rnn_prb_t *p, alg_t alg, int sic, int slc, int dic,
        int dlc, int wc, int batch, int n_layer, int n_iter, int n_dir,
        int n_states, float *ws_, const float *src_layer_,
        const float *firstit_states_, rnn_iter_direction_t iter_dir,
        rnn_layer_direction_t lay_dir, int dir_val) {
    // ws has a one-cell border on the layer and iteration axes (hence +2).
    AOC<float> ws(ws_, n_layer + 2, n_dir, n_iter + 2, n_states, batch * wc);
    AOC<const float> src_layer(src_layer_, n_iter, batch * slc);
    AOC<const float> firstit_states(
            firstit_states_, n_layer, n_dir, n_states, batch * sic);
    // Border indices depend on traversal direction.
    int lay_dest = (lay_dir == bottom2top) ? 0 : n_layer + 1;
    int it_dest = (iter_dir == left2right) ? 0 : n_iter + 1;
    bool is_int8 = p->cfg[input].dt == mkldnn_u8;
    // Layer input -> border layer row, one copy per iteration.
    for (int it = 0; it < n_iter; it++) {
        copy(batch, slc, slc, wc, &src_layer(it, 0),
                &ws(lay_dest, dir_val, it + 1, H, 0));
        if (p->cfg[input].dt == mkldnn_u8)
            // shift u8 input to s8 to avoid compensation in gemm
            shift(batch, slc, wc, &ws(lay_dest, dir_val, it + 1, H, 0),
                    -1. * p->data_shift);
    // Initial iteration states -> border iteration column, per layer.
    for (int lay = 0; lay < n_layer; lay++) {
        copy(batch, sic, sic, wc, &firstit_states(lay, dir_val, H, 0),
                &ws(lay + 1, dir_val, it_dest, H, 0));
        if (p->cfg[states].dt == mkldnn_u8)
            shift(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, H, 0),
                    -1. * p->data_shift);
        else if (p->cfg[states].dt == mkldnn_f32 && is_int8) {
            // f32 states fed to an int8 run must be scaled into that domain.
            scale(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, H, 0),
                    p->data_scale, true, p);
        // LSTM additionally carries the cell state C.
        if (alg == VANILLA_LSTM) {
            copy(batch, sic, sic, wc, &firstit_states(lay, dir_val, C, 0),
                    &ws(lay + 1, dir_val, it_dest, C, 0));
            if (p->cfg[states].dt == mkldnn_u8) {
                shift(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, C, 0),
                        -1. * p->data_shift);
                scale(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, C, 0),
 * bwd: wsb keeps {dh, dc, dx} for every cell
// Seeds the backward workspace: copies diff_dst_layer into the border layer
// row (slot n_states holds dx) and the last-iteration diffs into the border
// iteration column. For bidirectional-concat, the layer diff is strided by
// 2*dlc and offset by direction.
void copy_init_bwd(alg_t alg, int sic, int slc, int dic, int dlc, int wc,
        int batch, int n_layer, int n_iter, int n_dir, int n_states, float *ws_,
        const float *src_layer_, const float *firstit_states_,
        rnn_iter_direction_t iter_dir, rnn_layer_direction_t lay_dir,
        int dir_val, bool is_concat = false) {
        ws_, n_layer + 2, n_dir, n_iter + 2, n_states + 1, batch * wc);
    auto c_stride = is_concat ? 2 * dlc : dlc;
    AOC<const float> src_layer(src_layer_, n_iter, batch * c_stride);
    AOC<const float> firstit_states(
            firstit_states_, n_layer, n_dir, n_states, batch * dic);
    // Border indices depend on traversal direction.
    int lay_dest = (lay_dir == bottom2top) ? 0 : n_layer + 1;
    int it_dest = (iter_dir == left2right) ? 0 : n_iter + 1;
    // Output-layer diffs into the extra (n_states) slot of the border row.
    for (int it = 0; it < n_iter; it++)
        copy(batch, dic, c_stride, wc,
                &src_layer(it, dir_val * is_concat * dlc),
                &ws(lay_dest, dir_val, it + 1, n_states, 0));
    // Last-iteration H (and C for LSTM) diffs into the border column.
    for (int lay = 0; lay < n_layer; lay++) {
        copy(batch, dic, dic, wc, &firstit_states(lay, dir_val, H, 0),
                &ws(lay + 1, dir_val, it_dest, H, 0));
        if (alg == VANILLA_LSTM) {
            copy(batch, dic, dic, wc, &firstit_states(lay, dir_val, C, 0),
                    &ws(lay + 1, dir_val, it_dest, C, 0));
// Extracts forward results from the workspace into user buffers:
// last-layer H for every iteration (concatenated or summed for
// bidirectional runs) and last-iteration H/C for every layer. For u8
// destinations the internal s8/f32 values are shifted/scaled back.
void copy_res_fwd(const rnn_prb_t *p, alg_t alg, int sic, int slc, int dic,
        int dlc, int wc, int batch, int n_layer, int n_iter, int n_dir,
        int n_states, float *lastit_states_, float *lastlay_states_,
        const float *ws_, rnn_iter_direction_t iter_dir,
        rnn_layer_direction_t lay_dir, int dir_val, rnn_action_t action,
        bool is_concat = false) {
    // Concat mode interleaves both directions along the channel axis.
    int lastlay_c = is_concat ? 2 * dlc : dlc;
    AOC<float> lastit_states(
            lastit_states_, n_layer, n_dir, n_states, batch, dic);
    AOC<float> lastlay_states(lastlay_states_, n_iter, batch, lastlay_c);
        ws_, n_layer + 2, n_dir, n_iter + 2, n_states, batch, wc);
    // Last-layer output per iteration/batch row.
    for (int it = 0; it < n_iter; it++) {
        for (int nb = 0; nb < batch; nb++) {
            auto from = &ws(n_layer, dir_val, it + 1, H, nb, 0);
            auto to = &lastlay_states(
                    it, nb, action == action_concat ? dlc : 0);
            copy(1, dlc, wc, lastlay_c, from, to, action);
            if (p->cfg[dst_last_layer].dt == mkldnn_u8) {
                // shift s8 internal ws to u8
                shift(1, dlc, lastlay_c, to, p->data_shift);
                // dequantize back to the user scale otherwise
                scale(1, dlc, lastlay_c, to, 1. / p->data_scale);
    // Source column in ws depends on iteration direction.
    int it_source = (iter_dir == left2right) ? n_iter : 1;
    // Copy states iteration
    for (int lay = 0; lay < n_layer; lay++) {
        if (alg == VANILLA_LSTM) {
            copy(batch, dic, wc, dic, &ws(lay + 1, dir_val, it_source, C, 0, 0),
                    &lastit_states(lay, dir_val, C, 0, 0));
            if (p->cfg[dst_last_iteration].dt == mkldnn_u8) {
                // quantize internal f32 ws to u8
                scale(batch, dic, dic, &lastit_states(lay, dir_val, C, 0, 0),
                shift(batch, dic, dic, &lastit_states(lay, dir_val, C, 0, 0),
                        p->data_shift, true, p);
        copy(batch, dic, wc, dic, &ws(lay + 1, dir_val, it_source, H, 0, 0),
                &lastit_states(lay, dir_val, H, 0, 0));
        if (p->cfg[dst_last_iteration].dt == mkldnn_u8) {
            // shift s8 internal ws to u8
            shift(batch, dic, dic, &lastit_states(lay, dir_val, H, 0, 0),
            scale(batch, dic, dic, &lastit_states(lay, dir_val, H, 0, 0),
// Extracts backward results from the workspace: the dx diffs (extra
// n_states slot of layer row 1) per iteration, and the per-layer H/C
// diffs at the first visited iteration column.
void copy_res_bwd(alg_t alg, int sic, int slc, int dic, int dlc, int wc,
        int batch, int n_layer, int n_iter, int n_dir, int n_states,
        float *lastit_states_, float *lastlay_states_, const float *ws_,
        rnn_iter_direction_t iter_dir, rnn_layer_direction_t lay_dir,
        int dir_val, rnn_action_t action) {
    AOC<float> lastit_states(
            lastit_states_, n_layer, n_dir, n_states, batch, sic);
    AOC<float> lastlay_states(lastlay_states_, n_iter, batch, slc);
        ws_, n_layer + 2, n_dir, n_iter + 2, n_states + 1, batch, wc);
    for (int it = 0; it < n_iter; it++) {
        for (int nb = 0; nb < batch; nb++) {
            // copy H to last layer states
            auto from = &ws(1, dir_val, it + 1, n_states, nb, 0);
            auto to = &lastlay_states(it, nb, 0);
            copy(1, slc, wc, slc, from, to, action);
    // Source column in ws depends on iteration direction.
    int it_source = (iter_dir == left2right) ? n_iter : 1;
    for (int lay = 0; lay < n_layer; lay++) {
        if (alg == VANILLA_LSTM) {
            copy(batch, sic, wc, sic, &ws(lay + 1, dir_val, it_source, C, 0, 0),
                    &lastit_states(lay, dir_val, C, 0, 0));
        copy(batch, sic, wc, sic, &ws(lay + 1, dir_val, it_source, H, 0, 0),
                &lastit_states(lay, dir_val, H, 0, 0));
// Reference forward pass over the whole (layer x iteration) grid.
// Seeds the workspace, runs rnn_cell_fwd for every cell in the requested
// direction(s), then copies the results out. Bidirectional runs execute the
// grid once per direction and combine via sum or concat.
void rnn_linear_fwd(const rnn_prb_t *p, mkldnn_rnn_direction_t direction,
        const float *src_iter_, const float *src_layer_,
        const float *weights_layer_, const float *weights_iter_h_,
        const float *bias_, float *dst_iter_, float *dst_layer_, float *ws_,
    const alg_t alg = p->alg;
    const int sic = p->sic;
    const int slc = p->slc;
    const int dic = p->dic;
    const int dlc = p->dlc;
    // Common leading dimension for all workspace planes.
    const int wc = max(sic, max(slc, dic));
    bool is_lbr = p->alg == LBR_GRU;
    bool is_concat = direction == mkldnn_bidirectional_concat;
    const int batch = p->mb;
    const int n_gates = p->n_gates();
    const int n_states = p->n_states();
    const int n_layer = p->n_layer;
    const int n_iter = p->n_iter;
    const int n_dir = p->n_directions();
    activation_t f = p->activation;
    // LBR-GRU carries one extra bias row per layer/direction.
    AOC<const float> bias(bias_, n_layer, n_dir, (n_gates + is_lbr) * dic);
    AOC<const float> weights_layer(
            weights_layer_, n_layer, n_dir, n_gates * dic, slc);
    AOC<const float> weights_iter(
            weights_iter_h_, n_layer, n_dir, n_gates * dic, sic);
    AOC<float> ws(ws_, n_layer + 2, n_dir, n_iter + 2, n_states, batch, wc);
    AOC<float> gates(gates_, n_layer, n_dir, n_iter, batch, n_gates, dic);
    // Scratch needed only by LBR (size 0 otherwise).
    // NOTE(review): raw new[] -- the matching delete[] is presumably in the
    // truncated tail of this function; confirm it is freed on all paths.
    int ws_local_size = is_lbr * batch * n_gates * dic;
    float *ws_local_ = new float[ws_local_size];
    auto process_direction = [&](rnn_iter_direction_t iter_dir,
            rnn_layer_direction_t lay_dir, int dir_val, rnn_action_t action) {
        // we first need to copy the initial states and input into ws
        // it simplifies the logic in the following code
        print(80, "rnn_linear_fwd: call copy_init dir_val = %d\n", dir_val);
        copy_init_fwd(p, alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter,
                n_dir, n_states, ws_, src_layer_, src_iter_, iter_dir, lay_dir,
        // We run the grid of computation
        for (int il = 0; il < n_layer; il++) {
            for (int it = 0; it < n_iter; it++) {
                print(80, "==== layer = %d iter = %d ===\n", il, it);
                // +1 offsets skip the workspace border cells.
                int iter = (iter_dir == left2right) ? it + 1 : n_iter - it;
                int prev_iter = (iter_dir == left2right) ? iter - 1 : iter + 1;
                rnn_cell_fwd(p, alg, f, sic, slc, dic, wc, batch, n_gates,
                        &ws(lay, dir_val, iter, H, 0, 0),
                        &ws(lay, dir_val, iter, C, 0, 0),
                        &gates(lay - 1, dir_val, iter - 1, 0, 0, 0),
                        &weights_layer(lay - 1, dir_val, 0, 0),
                        &weights_iter(lay - 1, dir_val, 0, 0),
                        &bias(lay - 1, dir_val, 0),
                        &ws(lay - 1, dir_val, iter, H, 0, 0),
                        &ws(lay, dir_val, prev_iter, H, 0, 0),
                        &ws(lay, dir_val, prev_iter, C, 0, 0), ws_local_);
        // Finally we copy the results to the result buffers
        copy_res_fwd(p, alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter,
                n_dir, n_states, dst_iter_, dst_layer_, ws_, iter_dir, lay_dir,
                dir_val, action, is_concat);
    // Dispatch on direction; bidirectional runs both passes, the second one
    // combining into the first via action_sum / action_concat.
    case mkldnn_unidirectional_left2right:
        process_direction(left2right, bottom2top, 0, action_copy);
    case mkldnn_unidirectional_right2left:
        process_direction(right2left, bottom2top, 0, action_copy);
    case mkldnn_bidirectional_sum:
        process_direction(left2right, bottom2top, 0, action_copy);
        process_direction(right2left, bottom2top, 1, action_sum);
    case mkldnn_bidirectional_concat:
        process_direction(left2right, bottom2top, 0, action_copy);
        process_direction(right2left, bottom2top, 1, action_concat);
    default: assert("unknown direction"); break;
// Entry point for the forward reference computation: sizes and allocates the
// workspace and gate buffers from the problem descriptor, then delegates to
// rnn_linear_fwd. NOTE(review): raw new[] allocations -- the matching
// delete[] calls are presumably in the truncated tail; confirm.
void compute_ref_fwd(const rnn_prb_t *p, dnn_mem_t &src_layer_m,
        dnn_mem_t &src_iter_m, dnn_mem_t &weights_src_layer_m,
        dnn_mem_t &weights_src_iter_m, dnn_mem_t &bias_m,
        dnn_mem_t &dst_last_layer_m, dnn_mem_t &dst_last_iteration_m,
        mkldnn_rnn_direction_t direction) {
    assert(direction == mkldnn_unidirectional_left2right
            || direction == mkldnn_unidirectional_right2left
            || direction == mkldnn_bidirectional_sum
            || direction == mkldnn_bidirectional_concat);
    const int wc = max(p->sic, max(p->slc, p->dic));
    // Workspace has a one-cell border on the layer/iteration axes (+2).
    int ws_size = (p->n_layer + 2) * p->n_directions() * (p->n_iter + 2)
            * p->n_states() * p->mb * wc;
    auto *ws = new float[ws_size];
    int gates_size = p->n_layer * p->n_directions() * p->n_iter * p->mb
            * p->n_gates() * p->dic;
    auto *gates = new float[gates_size];
    rnn_linear_fwd(p, direction, (float *)src_iter_m, (float *)src_layer_m,
            (float *)weights_src_layer_m, (float *)weights_src_iter_m,
            (float *)bias_m, (float *)dst_last_iteration_m,
            (float *)dst_last_layer_m, ws, gates);
618 // =============================================================================
619 // ================ BACKWARD ===================================================
620 // =============================================================================
// Vanilla RNN cell, backward pass: builds the gate diffs from the incoming
// dh (layer + iteration), accumulates weight and bias gradients, and
// produces the input diffs w.r.t. layer and iteration inputs.
void rnn_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
        int batch, int n_gates, float *diff_src_layer_, float *diff_src_iter_,
        float *diff_weights_layer_, float *diff_weights_iter_h_,
        float *diff_bias_, float *b_gates_, const float *src_layer_,
        const float *src_iter_, const float *weights_layer_,
        const float *weights_iter_h_, const float *bias_,
        const float *dst_iter_h_, const float *gates_,
        const float *diff_dst_layer_, const float *diff_dst_iter_h_) {
    AOC<const float> diff_dst_layer(diff_dst_layer_, batch, wc);
    AOC<const float> diff_dst_iter_h(diff_dst_iter_h_, batch, wc);
    AOC<const float> gates(gates_, batch, n_gates, dic);
    AOC<float> b_gates(b_gates_, batch, n_gates, dic);
    // dgate = f'(gate_out) * (dh_layer + dh_iter); the derivative is taken
    // in terms of the stored forward output (see activation()).
    for (int b = 0; b < batch; ++b)
        for (int h = 0; h < dic; ++h) {
            const float g = gates(b, 0, h);
            const float dd = diff_dst_layer(b, h) + diff_dst_iter_h(b, h);
            b_gates(b, 0, h) = activation(f, g, false) * dd;
    // Weight gradients accumulate (beta = 1) across cells.
    gemm("C", "T", "N", sic, n_gates * dic, batch, 1.0, src_iter_, wc, b_gates_,
            n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic);
    gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc, b_gates_,
            n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic);
    // Bias gradient: sum the gate diffs over the batch.
    for (int b = 0; b < batch; ++b)
        copy(n_gates, dic, dic, dic, &b_gates(b, 0, 0), diff_bias_, action_sum);
    // Input diffs overwrite (beta = 0).
    gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic,
            weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc);
    gemm("C", "N", "T", batch, sic, n_gates * dic, 1.0, b_gates_, n_gates * dic,
            weights_iter_h_, n_gates * dic, 0.0, diff_src_iter_, wc);
// LSTM cell, backward pass: recovers the four gate diffs from dh/dc using
// the saved forward gate outputs, then accumulates weight/bias gradients and
// produces the input diffs.
void lstm_bwd(alg_t alg, int sic, int slc, int dic, int wc, int batch,
        int n_gates, float *diff_src_layer_, float *diff_src_iter_h_,
        float *diff_src_iter_c_, float *diff_weights_layer_,
        float *diff_weights_iter_h_, float *diff_bias_, float *b_gates_,
        const float *src_layer_, const float *src_iter_h_,
        const float *src_iter_c_, const float *weights_layer_,
        const float *weights_iter_h_, const float *bias_,
        const float *dst_iter_h_, const float *dst_iter_c_, const float *gates_,
        const float *diff_dst_layer_, const float *diff_dst_iter_h_,
        const float *diff_dst_iter_c_) {
    // TODO: check sic and slc as last dimension in arrays and cycles
    AOC<const float> diff_dst_layer(diff_dst_layer_, batch, wc);
    AOC<const float> diff_dst_iter_c(diff_dst_iter_c_, batch, wc);
    AOC<const float> diff_dst_iter_h(diff_dst_iter_h_, batch, wc);
    AOC<const float> src_iter_c(src_iter_c_, batch, wc);
    AOC<const float> dst_iter_h(dst_iter_h_, batch, wc);
    AOC<const float> dst_iter_c(dst_iter_c_, batch, wc);
    AOC<const float> gates(gates_, batch, n_gates, dic);
    AOC<float> diff_src_iter_c(diff_src_iter_c_, batch, wc);
    AOC<float> b_gates(b_gates_, batch, n_gates, dic);
    // Elementwise gate diffs. x_m_square / one_m_square are the sigmoid /
    // tanh derivatives in terms of the stored forward outputs.
    for (int ib = 0; ib < batch; ib++)
        for (int ih = 0; ih < dic; ih++) {
            print(80, "rnn_single_bwd: ib = %d ih = %d\n", ib, ih);
            float ho = gates(ib, oho, ih);
            float hf = gates(ib, ohf, ih);
            float hc = gates(ib, ohc, ih);
            float hi = gates(ib, ohi, ih);
            float dh = diff_dst_layer(ib, ih) + diff_dst_iter_h(ib, ih);
            float c = dst_iter_c(ib, ih);
            // do = tanh(c) * dh
            float dho = tanhf(c) * dh;
            b_gates(ib, oho, ih) = x_m_square(ho) * dho;
            // dc = o * dh * tanh'(c) + dc_next; dc_prev = f * dc
            float dc_next = diff_dst_iter_c(ib, ih);
            float dc = ho * dh * dtanhf(c) + dc_next;
            diff_src_iter_c(ib, ih) = hf * dc;
            // df = c_prev * dc
            float c_old = src_iter_c(ib, ih);
            float dhf = c_old * dc;
            b_gates(ib, ohf, ih) = x_m_square(hf) * dhf;
            b_gates(ib, ohi, ih) = x_m_square(hi) * dhi;
            b_gates(ib, ohc, ih) = one_m_square(hc) * dhc;
    // Weight gradients accumulate (beta = 1) across cells.
    gemm("C", "T", "N", sic, n_gates * dic, batch, 1.0, src_iter_h_, wc, b_gates_,
            n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic);
    gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc, b_gates_,
            n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic);
    // Input diffs overwrite (beta = 0).
    gemm("C", "N", "T", batch, sic, n_gates * dic, 1.0, b_gates_, n_gates * dic,
            weights_iter_h_, n_gates * dic, 0.0, diff_src_iter_h_, wc);
    gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic,
            weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc);
    // Bias gradient: sum the gate diffs over the batch.
    for (int i = 0; i < batch; i++)
        for (int j = 0; j < n_gates; j++)
            for (int k = 0; k < dic; k++)
                diff_bias_[j * dic + k] += b_gates(i, j, k);
// GRU cell, backward pass. Uses ws_local_ as scratch for dhr (diff through
// the reset product) and hr (h_prev * r). Gate order matches forward:
// ohu = update, ohr = reset, ohc = candidate.
void gru_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
        int batch, int n_gates, float *diff_src_layer_, float *diff_src_iter_,
        float *diff_weights_layer_, float *diff_weights_iter_h_,
        float *diff_bias_, float *b_gates_, const float *src_layer_,
        const float *src_iter_, const float *weights_layer_,
        const float *weights_iter_h_, const float *bias_,
        const float *dst_iter_h_, const float *gates_,
        const float *diff_dst_layer_, const float *diff_dst_iter_h_,
    AOC<const float> src_iter(src_iter_, batch, wc);
    AOC<const float> diff_dst_layer(diff_dst_layer_, batch, wc);
    AOC<const float> diff_dst_iter_h(diff_dst_iter_h_, batch, wc);
    AOC<const float> gates(gates_, batch, n_gates, dic);
    AOC<const float> weights_layer(weights_layer_, slc, n_gates, dic);
    AOC<const float> weights_iter_h(weights_iter_h_, sic, n_gates, dic);
    AOC<float> diff_src_iter(diff_src_iter_, batch, wc);
    AOC<float> diff_weights_iter_h(diff_weights_iter_h_, sic, n_gates, dic);
    AOC<float> b_gates(b_gates_, batch, n_gates, dic);
    // Scratch partition: dhr first, then hr.
    float *dhr_ = ws_local_;
    float *hr_ = ws_local_ + batch * wc;
    AOC<float> dhr(dhr_, batch, wc);
    AOC<float> hr(hr_, batch, wc);
    // dc = (1 - u) * dh; dc^ = one_m_square(c) * dc;
    // du = (h - u) * dh; du^ = x_m_square(u) * du;
    // dr = h * dhr; dr^ = x_m_square(r) * dr;
    for (int ib = 0; ib < batch; ib++)
        for (int ih = 0; ih < dic; ih++) {
            float h = src_iter(ib, ih);
            float c = gates(ib, ohc, ih);
            float u = gates(ib, ohu, ih);
            float dh = diff_dst_layer(ib, ih) + diff_dst_iter_h(ib, ih);
            float du = (h - c) * dh;
            float dc = (1.0f - u) * dh;
            b_gates(ib, ohu, ih) = x_m_square(u) * du;
            b_gates(ib, ohc, ih) = one_m_square(c) * dc;
            // Direct path of dh into h_prev through the update blend.
            diff_src_iter(ib, ih) = dh * u;
    // dhr = dc^ * W_iter[c]^T.
    gemm("C", "N", "T", batch, sic, dic, 1.0, &(b_gates(0, 2, 0)), n_gates * dic,
            &(weights_iter_h(0, 2, 0)), n_gates * dic, 0.0, dhr_, wc);
    // Reset-gate diff and the r-weighted path of dhr into h_prev.
    for (int ib = 0; ib < batch; ib++)
        for (int ih = 0; ih < dic; ih++) {
            float h = src_iter(ib, ih);
            float r = gates(ib, ohr, ih);
            float dr = h * dhr(ib, ih);
            diff_src_iter(ib, ih) += dhr(ib, ih) * r;
            b_gates(ib, ohr, ih) = x_m_square(r) * dr;
    // dWx += xdu^ | xdr^ | xdc^
    // dWh += hdu^ | ddr^ | (h * r)dc^
    gemm("C", "T", "N", sic, (n_gates - 1) * dic, batch, 1.0, src_iter_, wc,
            b_gates_, n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic);
    gemm("C", "T", "N", sic, dic, batch, 1.0, hr_, wc, &(b_gates(0, 2, 0)),
            n_gates * dic, 1.0, &(diff_weights_iter_h(0, 2, 0)), n_gates * dic);
    gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc,
            b_gates_, n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic);
    // dx_next = Wxudu^ + Wxrdr^ + Wxcdc^
    // dh_next = dh * u + Whudu^ + Whzdz^ + r * Whcdc^
    gemm("C", "N", "T", batch, sic, (n_gates - 1)* dic, 1.0, b_gates_,
            n_gates * dic, weights_iter_h_, n_gates * dic, 1.0, diff_src_iter_,
    gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic,
            weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc);
    // Bias gradient: sum the gate diffs over the batch.
    for (int i = 0; i < batch; i++)
        for (int j = 0; j < n_gates; j++)
            for (int k = 0; k < dic; k++)
                diff_bias_[j * dic + k] += b_gates(i, j, k);
// Linear-before-reset GRU cell, backward pass. Scratch (ws_local_) holds
// Wh_b = h_prev * W_iter[c] + b_recur and b_gates_r, a copy of the gate
// diffs with the candidate diff pre-multiplied by r (needed because the
// recurrent weights saw r applied after the matmul in forward).
void gru_lbr_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
        int batch, int n_gates, float *diff_src_layer_, float *diff_src_iter_,
        float *diff_weights_layer_, float *diff_weights_iter_h_,
        float *diff_bias_, float *b_gates_, const float *src_layer_,
        const float *src_iter_, const float *weights_layer_,
        const float *weights_iter_h_, const float *bias_,
        const float *dst_iter_h_, const float *gates_,
        const float *diff_dst_layer_, const float *diff_dst_iter_h_,
    AOC<const float> src_iter(src_iter_, batch, wc);
    AOC<const float> diff_dst_layer(diff_dst_layer_, batch, wc);
    AOC<const float> diff_dst_iter_h(diff_dst_iter_h_, batch, wc);
    AOC<const float> gates(gates_, batch, n_gates, dic);
    AOC<const float> weights_layer(weights_layer_, slc, n_gates, dic);
    AOC<const float> weights_iter_h(weights_iter_h_, sic, n_gates, dic);
    AOC<const float> bias(bias_, n_gates + 1, dic);
    AOC<float> diff_src_iter(diff_src_iter_, batch, wc);
    AOC<float> diff_weights_iter_h(diff_weights_iter_h_, dic, n_gates, sic);
    AOC<float> b_gates(b_gates_, batch, n_gates, dic);
    // Scratch partition: Wh_b first, then b_gates_r.
    float *Wh_b_ = ws_local_;
    float *b_gates_r_ = ws_local_ + dic * batch;
    AOC<float> Wh_b(Wh_b_, batch, dic);
    AOC<float> b_gates_r(b_gates_r_, batch, n_gates, dic);
    // Wh_b = b_recur, then += h_prev * W_iter[c].
    for (int ib = 0; ib < batch; ib++)
        for (int ih = 0; ih < dic; ih++)
            Wh_b(ib, ih) = bias(3, ih);
    gemm("C", "N", "N", batch, dic, sic, 1.0, src_iter_, wc,
            &weights_iter_h(0, 2, 0), n_gates * dic, 1.0, Wh_b_, dic);
    // dc = (1 - u) * dh; dc^ = one_m_square(c) * dc;
    // du = (h - c) * dh; du^ = x_m_square(u) * du;
    // dr = (Wh + b) * dc^; dr^ = x_m_square(r) * dr;
    for (int ib = 0; ib < batch; ib++)
        for (int ih = 0; ih < dic; ih++) {
            float h = src_iter(ib, ih);
            float dh = diff_dst_layer(ib, ih) + diff_dst_iter_h(ib, ih);
            float u = gates(ib, ohu, ih);
            float r = gates(ib, ohr, ih);
            float c = gates(ib, ohc, ih);
            float du = (h - c) * dh;
            float dc = (1.0f - u) * dh;
            b_gates(ib, ohu, ih) = x_m_square(u) * du;
            b_gates(ib, ohc, ih) = one_m_square(c) * dc;
            float dr = Wh_b(ib, ih) * b_gates(ib, ohc, ih);
            b_gates(ib, ohr, ih) = x_m_square(r) * dr;
            // r-weighted copy used for the recurrent-weight/input gradients.
            b_gates_r(ib, ohu, ih) = b_gates(ib, ohu, ih);
            b_gates_r(ib, ohr, ih) = b_gates(ib, ohr, ih);
            b_gates_r(ib, ohc, ih) = b_gates(ib, ohc, ih) * r;
            diff_src_iter(ib, ih) = dh * u;
    // Weight gradients: recurrent side uses the r-weighted diffs.
    gemm("C", "T", "N", sic, n_gates * dic, batch, 1.0, src_iter_, wc, b_gates_r_,
            n_gates * dic, 1.0, diff_weights_iter_h_, n_gates * dic);
    gemm("C", "T", "N", slc, n_gates * dic, batch, 1.0, src_layer_, wc, b_gates_,
            n_gates * dic, 1.0, diff_weights_layer_, n_gates * dic);
    // Input diffs: layer side overwrites, iteration side accumulates onto
    // the dh*u term stored above.
    gemm("C", "N", "T", batch, slc, n_gates * dic, 1.0, b_gates_, n_gates * dic,
            weights_layer_, n_gates * dic, 0.0, diff_src_layer_, wc);
    gemm("C", "N", "T", batch, sic, n_gates * dic, 1.0, b_gates_r_, n_gates * dic,
            weights_iter_h_, n_gates * dic, 1.0, diff_src_iter_, wc);
    // Bias gradients; the extra LBR bias row gets the r-weighted candidate.
    for (int i = 0; i < batch; i++)
        for (int j = 0; j < n_gates; j++)
            for (int k = 0; k < dic; k++)
                diff_bias_[j * dic + k] += b_gates(i, j, k);
    for (int i = 0; i < batch; i++)
        for (int k = 0; k < dic; k++)
            diff_bias_[3 * dic + k] += b_gates_r(i, 2, k);
// Reference backward pass for one RNN cell at a single (layer, iteration)
// grid point: routes to the implementation matching the cell algorithm.
// Each *_bwd call computes diff_src_*, diff_weights_*, diff_bias and the
// per-gate scratch `b_gates` from the forward results (gates, dst_iter_*)
// and the incoming gradients (diff_dst_*).
// NOTE(review): the switch-on-`alg` and its case labels between the calls
// below are not visible in this extract (source lines 900-902, 908-910,
// 915-917, 923-924 are missing) — confirm the dispatch structure against
// the full file before editing.
890 void rnn_cell_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
891 int batch, int n_gates, float *diff_src_layer, float *diff_src_iter_h,
892 float *diff_src_iter_c, float *diff_weights_layer,
893 float *diff_weights_iter, float *diff_bias, float *b_gates,
894 const float *src_layer, const float *src_iter_h,
895 const float *src_iter_c, const float *weights_layer,
896 const float *weights_iter, const float *bias, const float *dst_iter_h,
897 const float *dst_iter_c, const float *gates,
898 const float *diff_dst_layer, const float *diff_dst_iter_h,
899 const float *diff_dst_iter_c, float *ws_local_) {
// LSTM: the only variant that also consumes/produces the cell state
// (src_iter_c / dst_iter_c / diff_dst_iter_c / diff_src_iter_c).
903 lstm_bwd(alg, sic, slc, dic, wc, batch, n_gates, diff_src_layer,
904 diff_src_iter_h, diff_src_iter_c, diff_weights_layer,
905 diff_weights_iter, diff_bias, b_gates, src_layer, src_iter_h,
906 src_iter_c, weights_layer, weights_iter, bias, dst_iter_h,
907 dst_iter_c, gates, diff_dst_layer, diff_dst_iter_h,
// Vanilla RNN: activation `f` selects relu/logistic/tanh derivative.
// NOTE(review): trailing argument line(s) of this call are missing from
// the extract (source lines 915+).
911 rnn_bwd(alg, f, sic, slc, dic, wc, batch, n_gates, diff_src_layer,
912 diff_src_iter_h, diff_weights_layer, diff_weights_iter,
913 diff_bias, b_gates, src_layer, src_iter_h, weights_layer,
914 weights_iter, bias, dst_iter_h, gates, diff_dst_layer,
// GRU: uses ws_local_ as scratch (sized 2 * batch * wc by the caller).
918 gru_bwd(alg, f, sic, slc, dic, wc, batch, n_gates, diff_src_layer,
919 diff_src_iter_h, diff_weights_layer, diff_weights_iter,
920 diff_bias, b_gates, src_layer, src_iter_h, weights_layer,
921 weights_iter, bias, dst_iter_h, gates, diff_dst_layer,
922 diff_dst_iter_h, ws_local_);
// Linear-before-reset GRU: same signature as gru_bwd; scratch holds
// batch * (n_gates + 1) * dic floats (see caller's sizing).
925 gru_lbr_bwd(alg, f, sic, slc, dic, wc, batch, n_gates, diff_src_layer,
926 diff_src_iter_h, diff_weights_layer, diff_weights_iter,
927 diff_bias, b_gates, src_layer, src_iter_h, weights_layer,
928 weights_iter, bias, dst_iter_h, gates, diff_dst_layer,
929 diff_dst_iter_h, ws_local_);
// Reference backward pass over the whole RNN: iterates the
// (layer x iteration) grid in reverse for each direction, calling
// rnn_cell_bwd per grid point, and accumulates diff_weights/diff_bias.
// `ws_` holds the forward-pass states (read-only here); `wsb_` is a
// locally allocated workspace for the backward diff states.
// NOTE(review): several continuation lines are missing from this extract
// (e.g. source lines 977, 981, 983-993, 1004, 1010, 1036-1045, 1049,
// 1052, 1056, 1060, 1062-1067) — comments below describe only the
// visible code; confirm against the full file.
934 void rnn_linear_bwd(const rnn_prb_t *p, mkldnn_rnn_direction_t direction,
935 const float *diff_dst_iter_, const float *diff_dst_layer_,
936 const float *weights_layer_, const float *weights_iter_h_,
937 const float *bias_, float *diff_src_iter_, float *diff_src_layer_,
938 float *diff_weights_layer_, float *diff_weights_iter_h_,
939 float *diff_bias_, float *ws_, const float *gates_) {
// Problem descriptor shorthands.
941 const alg_t alg = p->alg;
942 const int sic = p->sic;
943 const int slc = p->slc;
944 const int dic = p->dic;
945 const int dlc = p->dlc;
// wc = widest channel count; every ws row is padded to this width.
946 const int wc = max(sic, max(slc, dic));
// LBR GRU carries one extra bias row (n_gates + 1 below).
947 bool is_lbr = p->alg == LBR_GRU;
949 const int batch = p->mb;
950 const int n_gates = p->n_gates();
951 const int n_states = p->n_states();
952 const int n_layer = p->n_layer;
953 const int n_iter = p->n_iter;
954 const int n_dir = p->n_directions();
955 activation_t f = p->activation;
// Index of the extra "X" (layer-input diff) slot appended after the
// H/C state slots in wsb (hence n_states + 1 dimensions below).
957 const int X = n_states;
959 AOC<const float> bias(bias_, n_layer, n_dir, n_gates + is_lbr, dic);
960 AOC<float> diff_bias(diff_bias_, n_layer, n_dir, n_gates + is_lbr, dic);
962 AOC<const float> weights_layer(
963 weights_layer_, n_layer, n_dir, n_gates * dic, slc);
964 AOC<const float> weights_iter(
965 weights_iter_h_, n_layer, n_dir, n_gates * dic, sic);
967 AOC<float> diff_weights_layer(
968 diff_weights_layer_, n_layer, n_dir, n_gates * dic, slc);
969 AOC<float> diff_weights_iter(
970 diff_weights_iter_h_, n_layer, n_dir, n_gates * dic, sic);
// Per-cell gate-gradient scratch, reused at every grid point.
// NOTE(review): raw new[]; the matching delete[] is presumably at the
// (not visible) end of this function — verify no early-return leaks.
972 auto *b_gates = new float[batch * n_gates * dic];
// Forward workspace (+2 on layer/iter dims for the boundary halo rows).
973 AOC<float> ws(ws_, n_layer + 2, n_dir, n_iter + 2, n_states, batch, wc);
974 AOC<const float> gates(gates_, n_layer, n_dir, n_iter, batch, n_gates, dic);
// Backward diff workspace; n_states + 1 adds the X slot per grid point.
// NOTE(review): the `* wc` continuation of this expression (source line
// 977) is missing from the extract.
976 int wsb_size = (n_layer + 2) * n_dir * (n_iter + 2) * (n_states + 1) * batch
978 auto *wsb_ = new float[wsb_size];
979 init_buffer(wsb_, wsb_size, 0.); // ??!! Temporary. For debug.
980 // n_states + 1 -- H, C, X
// NOTE(review): the `AOC<float> wsb(` head of this declaration (source
// line 981) is missing from the extract.
982 wsb_, n_layer + 2, n_dir, n_iter + 2, n_states + 1, batch, wc);
// Scratch sizing per cell algorithm; presumably cases of a switch on
// `alg` whose labels are not visible here (LBR GRU vs GRU vs others).
987 ws_local_size = batch * (n_gates + 1) * dic;
990 ws_local_size = 2 * batch * wc;
992 default: ws_local_size = 0;
994 float *ws_local_ = new float[ws_local_size];
// Runs one direction of the backward grid; `dir_val` selects the
// direction slot in the workspaces, `action` tells copy_res_bwd whether
// to overwrite or accumulate into the user diff buffers.
996 auto process_direction = [&](rnn_iter_direction_t iter_dir,
997 rnn_layer_direction_t lay_dir, int dir_val, rnn_action_t action) {
998 // we first need to copy the initial states and input into ws
999 // it simplifies the logic in the following code
1000 copy_init_bwd(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter,
1001 n_dir, n_states, wsb_, diff_dst_layer_, diff_dst_iter_,
1002 iter_dir, lay_dir, dir_val,
1003 direction == mkldnn_bidirectional_concat);
1005 // We run the grid of computation
// Layers are walked top-down (gradients flow from the last layer).
1006 for (int j = n_layer - 1; j >= 0; j--) {
1007 for (int i = 0; i < n_iter; i++) {
// Workspace indices are 1-based (slot 0 is the boundary halo).
1008 int iter = (iter_dir == left2right) ? i + 1 : n_iter - i;
1009 int prev_iter = (iter_dir == left2right) ? iter - 1 : iter + 1;
// NOTE(review): declaration of `lay` (source line 1010) is missing
// from the extract; prev_lay points one layer up in wsb.
1011 int prev_lay = lay + 1;
// NOTE(review): both ternary branches are identical — ws_iter == iter
// either way; looks like a leftover, confirm intent upstream.
1013 int ws_iter = (iter_dir == left2right) ? iter : iter;
1015 = (iter_dir == left2right) ? iter + 1 : iter - 1;
// Diffs for this cell come from wsb (this point + neighbors); forward
// activations come from ws/gates; weight/bias diffs accumulate into
// the user buffers via the (lay - 1)-indexed AOC views.
1017 rnn_cell_bwd(alg, f, sic, slc, dic, wc, batch, n_gates,
1018 &wsb(lay, dir_val, iter, X, 0, 0),
1019 &wsb(lay, dir_val, iter, H, 0, 0),
1020 &wsb(lay, dir_val, iter, C, 0, 0),
1021 &diff_weights_layer(lay - 1, dir_val, 0, 0),
1022 &diff_weights_iter(lay - 1, dir_val, 0, 0),
1023 &diff_bias(lay - 1, dir_val, 0, 0), b_gates,
1024 &ws(lay - 1, dir_val, ws_iter, H, 0, 0),
1025 &ws(lay, dir_val, ws_prev_iter, H, 0, 0),
1026 &ws(lay, dir_val, ws_prev_iter, C, 0, 0),
1027 &weights_layer(lay - 1, dir_val, 0, 0),
1028 &weights_iter(lay - 1, dir_val, 0, 0),
1029 &bias(lay - 1, dir_val, 0, 0),
1030 &ws(lay, dir_val, ws_iter, H, 0, 0),
1031 &ws(lay, dir_val, ws_iter, C, 0, 0),
1032 &gates(lay - 1, dir_val, ws_iter - 1, 0, 0, 0),
1033 &wsb(prev_lay, dir_val, iter, X, 0, 0),
1034 &wsb(lay, dir_val, prev_iter, H, 0, 0),
1035 &wsb(lay, dir_val, prev_iter, C, 0, 0),
1040 // Finally we copy the results to the result buffers
1041 copy_res_bwd(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter, n_dir,
1042 n_states, diff_src_iter_, diff_src_layer_, wsb_, iter_dir,
1043 lay_dir, dir_val, action);
// Backward traverses time in the opposite order of the forward pass,
// hence left2right forward -> right2left here (and vice versa).
1046 switch (direction) {
1047 case mkldnn_unidirectional_left2right:
1048 process_direction(right2left, top2bottom, 0, action_copy);
1050 case mkldnn_unidirectional_right2left:
1051 process_direction(left2right, top2bottom, 0, action_copy);
// Bidirectional: second direction sums its diffs into the buffers
// already written by the first (action_sum).
1053 case mkldnn_bidirectional_sum:
1054 process_direction(right2left, top2bottom, 0, action_copy);
1055 process_direction(left2right, top2bottom, 1, action_sum);
1057 case mkldnn_bidirectional_concat:
1058 process_direction(right2left, top2bottom, 0, action_copy);
1059 process_direction(left2right, top2bottom, 1, action_sum);
// NOTE(review): assert("unknown direction") asserts a non-null string
// literal, which is always true and so never fires; should likely be
// assert(!"unknown direction") to match assert(!"unknown activation")
// used elsewhere in this file.
1061 default: assert("unknown direction"); break;
1069 void compute_ref_bwd(const rnn_prb_t *p, dnn_mem_t &input_m,
1070 dnn_mem_t &states_m, dnn_mem_t &diff_last_layer_m,
1071 dnn_mem_t &diff_last_iteration_m, dnn_mem_t &weights_input_m,
1072 dnn_mem_t &weights_states_m, dnn_mem_t &bias_m,
1073 dnn_mem_t &dst_last_layer_m, dnn_mem_t &dst_last_iteration_m,
1074 dnn_mem_t &dst_diff_input_m, dnn_mem_t &dst_diff_states_m,
1075 dnn_mem_t &dst_diff_weights_input_m,
1076 dnn_mem_t &dst_diff_weights_states_m, dnn_mem_t &dst_diff_bias_m,
1077 mkldnn_rnn_direction_t direction) {
1078 // !! TODO: add support of strides
1080 assert(direction == mkldnn_unidirectional_left2right
1081 || direction == mkldnn_unidirectional_right2left
1082 || direction == mkldnn_bidirectional_sum
1083 || direction == mkldnn_bidirectional_concat);
1085 assert(p->dlc == p->dic);
1086 int wc = max(p->sic, max(p->slc, p->dic));
1087 int ws_size = (p->n_layer + 2) * p->n_directions() * (p->n_iter + 2)
1088 * p->n_states() * p->mb * wc;
1089 auto *ws = new float[ws_size];
1090 init_buffer(ws, ws_size, -55.); // ??!! Temporary. For debug.
1091 int gates_size = p->n_layer * p->n_directions() * p->n_iter * p->mb
1092 * p->n_gates() * p->dic;
1093 auto *gates = new float[gates_size];
1095 rnn_linear_fwd(p, direction, (float *)states_m, (float *)input_m,
1096 (float *)weights_input_m, (float *)weights_states_m,
1097 (float *)bias_m, (float *)dst_last_iteration_m,
1098 (float *)dst_last_layer_m, ws, gates);
1100 rnn_linear_bwd(p, direction, (float *)diff_last_iteration_m,
1101 (float *)diff_last_layer_m, (float *)weights_input_m,
1102 (float *)weights_states_m, (float *)bias_m,
1103 (float *)dst_diff_states_m, (float *)dst_diff_input_m,
1104 (float *)dst_diff_weights_input_m,
1105 (float *)dst_diff_weights_states_m, (float *)dst_diff_bias_m, ws,