float result = 0;
switch (f) {
case RELU: result = is_fwd ? relu(x) : drelu(x); break;
- case LOGISTIC: result = is_fwd ? logistic(x) : dlogistic(x); break;
- case TANH: result = is_fwd ? tanhf(x) : dtanhf(x); break;
+ case LOGISTIC: result = is_fwd ? logistic(x) : x_m_square(x); break;
+ case TANH: result = is_fwd ? tanhf(x) : one_m_square(x); break;
default: assert(!"unknown activation");
}
return result;
}
// w = [weights_layer | weights_iter] : with order f, i , o, \bar(c)
-void lstm_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates,
-        float *dst_iter_h_, float *c_dst_, float *gates_,
+// Reference LSTM forward cell: gates = src_layer x W_layer + src_iter_h x
+// W_iter + bias, eltwise activations, then C/H state update. For int8
+// configs the gate accumulators are dequantized (maybe_deq_w) and the H
+// output is re-quantized (maybe_q_d) using p->data_scale / p->data_shift.
+void lstm_fwd(const rnn_prb_t *p, int sic, int slc, int dic, int wc, int batch,
+        int n_gates, float *dst_iter_h_, float *c_dst_, float *gates_,
        const float *weights_layer_, const float *weights_iter_h_,
        const float *bias_, const float *src_layer_, const float *src_iter_h_,
        const float *src_iter_c_) {
    gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc,
            weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic);
-    gemm("C", "N", "N", batch, n_gates * dic, sic,1.0, src_iter_h_, wc,
+    gemm("C", "N", "N", batch, n_gates * dic, sic, 1.0, src_iter_h_, wc,
            weights_iter_h_, n_gates * dic, 1.0, gates_, n_gates * dic);
+    // Dequantize one gate accumulator back to f32: divide by the combined
+    // weights scale (common or per-output-channel `oc`) and the data scale.
+    // No-op for the f32 config.
+    auto maybe_deq_w = [&](float g, int oc) {
+        if (p->cfg == conf_f32)
+            return g;
+        float scale = 1.;
+        if (p->scale_policy == PER_OC)
+            scale = p->wei_oc_scales[oc];
+        else if (p->scale_policy == COMMON)
+            scale = p->wei_scale;
+        scale *= p->data_scale;
+        return g / scale;
+    };
+
    // add bias
    for (int i = 0; i < batch; i++)
        for (int j = 0; j < n_gates; j++)
            for (int k = 0; k < dic; k++) {
-                gates(i, j, k) += bias(j, k);
+                gates(i, j, k)
+                        = maybe_deq_w(gates(i, j, k), j * dic + k) + bias(j, k);
            }
    // run the eltwise
    lstm_activation(dic, n_gates, batch, gates_);
+    // Quantize an f32 result back to the data domain: scale, round according
+    // to p->attr.irmode, then saturate so that fp + data_shift stays within
+    // [p->cfg[input].min, p->cfg[input].max]. No-op for the f32 config.
+    auto maybe_q_d = [&](float h) {
+        if (p->cfg == conf_f32)
+            return h;
+        float fp = p->data_scale * h;
+        using R = attr_t::round_mode_t;
+        switch (p->attr.irmode) {
+        case R::DOWN: fp = floorf(fp); break;
+        case R::NEAREST: fp = nearbyintf(fp); break;
+        default: assert(!"unknown round mode");
+        }
+        if (fp + p->data_shift > p->cfg[input].max)
+            fp = p->cfg[input].max - p->data_shift;
+        if (fp + p->data_shift < p->cfg[input].min)
+            fp = p->cfg[input].min - p->data_shift;
+        return fp;
+    };
+
    // compute C_t_l and H_t_l
    for (int i = 0; i < batch; i++)
        for (int j = 0; j < dic; j++) {
            float tmp = gates(i, ohf, j) * src_iter_c(i, j)
                    + gates(i, ohi, j) * gates(i, ohc, j);
            c_dst(i, j) = tmp;
-            h_dst(i, j) = gates(i, oho, j) * tanhf(tmp);
+            h_dst(i, j) = maybe_q_d(gates(i, oho, j) * tanhf(tmp));
        }
}
+// Dispatches one forward cell computation by algorithm kind; only
+// VANILLA_GRU and VANILLA_LSTM are handled here (other algs fall through).
+// The problem descriptor `p` is forwarded to lstm_fwd for quantization.
-void rnn_cell_fwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
-        int batch, int n_gates, float *dst_iter_h, float *dst_iter_c,
-        float *gates, const float *weights_layer, const float *weights_iter,
-        const float *bias, const float *src_layer, const float *src_iter_h,
-        const float *src_iter_c, float *ws_local_) {
+void rnn_cell_fwd(const rnn_prb_t *p, alg_t alg, activation_t f, int sic,
+        int slc, int dic, int wc, int batch, int n_gates, float *dst_iter_h,
+        float *dst_iter_c, float *gates, const float *weights_layer,
+        const float *weights_iter, const float *bias, const float *src_layer,
+        const float *src_iter_h, const float *src_iter_c, float *ws_local_) {
    switch (alg) {
    case VANILLA_GRU:
        gru_fwd(sic, slc, dic, wc, batch, n_gates, dst_iter_h, gates,
                ws_local_);
        break;
    case VANILLA_LSTM:
-        lstm_fwd(sic, slc, dic, wc, batch, n_gates, dst_iter_h, dst_iter_c,
+        lstm_fwd(p, sic, slc, dic, wc, batch, n_gates, dst_iter_h, dst_iter_c,
                gates, weights_layer, weights_iter, bias, src_layer, src_iter_h,
                src_iter_c);
        break;
    default: break;
    }
}
+
void copy(int dimc, int dimr, int ld_src, int ld_dst, const float *src_,
float *dst_, rnn_action_t action = action_copy) {
AOC<const float> src(src_, dimc, ld_src);
});
}
-/* FIXME: separate copy_init ???
- * fwd: ws_states = n_states
- * bwd: ws_states = n_states + 1
- *
- * lstm example:
+// Adds `shift` to every element of the dimc x dimr buffer `src_` (leading
+// dimension ld_src), in place, parallel over rows. When `round` is true the
+// result is additionally rounded per p->attr.irmode and saturated to the u8
+// range [0, UINT8_MAX]; in that case `p` must be non-null.
+void shift(int dimc, int dimr, int ld_src, float *src_, float shift,
+        bool round = false, const rnn_prb_t *p = nullptr) {
+    AOC<float> src(src_, dimc, ld_src);
+    mkldnn::impl::parallel_nd(dimc, [&](int i) {
+        for (int j = 0; j < dimr; j++) {
+            float fp = src(i, j) + shift;
+            if (round) {
+                using R = attr_t::round_mode_t;
+                switch (p->attr.irmode) {
+                case R::DOWN: fp = floorf(fp); break;
+                case R::NEAREST: fp = nearbyintf(fp); break;
+                default: assert(!"unknown round mode");
+                }
+                // saturate to u8
+                if (fp > UINT8_MAX)
+                    fp = UINT8_MAX;
+                if (fp < 0)
+                    fp = 0;
+            }
+            src(i, j) = fp;
+        }
+    });
+}
+
+// Multiplies every element of the dimc x dimr buffer `src_` (leading
+// dimension ld_src) by `scale`, in place, parallel over rows. When `round`
+// is true the result is additionally rounded per p->attr.irmode (no
+// saturation, unlike shift()); in that case `p` must be non-null.
+void scale(int dimc, int dimr, int ld_src, float *src_, float scale,
+        bool round = false, const rnn_prb_t *p = nullptr) {
+    AOC<float> src(src_, dimc, ld_src);
+    mkldnn::impl::parallel_nd(dimc, [&](int i) {
+        for (int j = 0; j < dimr; j++) {
+            float fp = src(i, j) * scale;
+            if (round) {
+                using R = attr_t::round_mode_t;
+                switch (p->attr.irmode) {
+                case R::DOWN: fp = floorf(fp); break;
+                case R::NEAREST: fp = nearbyintf(fp); break;
+                default: assert(!"unknown round mode");
+                }
+            }
+            src(i, j) = fp;
+        }
+    });
+}
+
+/* lstm example:
 * fwd: ws keeps {h, c} for every cell
- * bwd: wsb keeps {dh, dc, dx} for every cell
 */
-void copy_init(alg_t alg, int sic, int slc, int dic, int dlc, int wc, int batch,
-        int n_layer, int n_iter, int n_states, float *ws_,
+// Seeds the forward workspace `ws_` with the input sequence and the initial
+// iteration states, converting between the external (possibly u8) and the
+// internal (s8/f32) representations via shift()/scale() where needed.
+void copy_init_fwd(const rnn_prb_t *p, alg_t alg, int sic, int slc, int dic,
+        int dlc, int wc, int batch, int n_layer, int n_iter, int n_dir,
+        int n_states, float *ws_, const float *src_layer_,
+        const float *firstit_states_, rnn_iter_direction_t iter_dir,
+        rnn_layer_direction_t lay_dir, int dir_val) {
+    AOC<float> ws(ws_, n_layer + 2, n_dir, n_iter + 2, n_states, batch * wc);
+    AOC<const float> src_layer(src_layer_, n_iter, batch * slc);
+    AOC<const float> firstit_states(
+            firstit_states_, n_layer, n_dir, n_states, batch * sic);
+
+    // border rows/columns of ws hold the boundary inputs for the grid
+    int lay_dest = (lay_dir == bottom2top) ? 0 : n_layer + 1;
+    int it_dest = (iter_dir == left2right) ? 0 : n_iter + 1;
+    bool is_int8 = p->cfg[input].dt == mkldnn_u8;
+
+    // Copy input
+    for (int it = 0; it < n_iter; it++) {
+        copy(batch, slc, slc, wc, &src_layer(it, 0),
+                &ws(lay_dest, dir_val, it + 1, H, 0));
+        if (p->cfg[input].dt == mkldnn_u8)
+            // shift u8 input to s8 to avoid compensation in gemm
+            shift(batch, slc, wc, &ws(lay_dest, dir_val, it + 1, H, 0),
+                    -1. * p->data_shift);
+    }
+
+    // Copy states
+    for (int lay = 0; lay < n_layer; lay++) {
+        copy(batch, sic, sic, wc, &firstit_states(lay, dir_val, H, 0),
+                &ws(lay + 1, dir_val, it_dest, H, 0));
+        if (p->cfg[states].dt == mkldnn_u8)
+            shift(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, H, 0),
+                    -1. * p->data_shift);
+        else if (p->cfg[states].dt == mkldnn_f32 && is_int8) {
+            // quantize to s8
+            scale(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, H, 0),
+                    p->data_scale, true, p);
+        }
+
+        if (alg == VANILLA_LSTM) {
+            // C state is kept in f32 internally
+            copy(batch, sic, sic, wc, &firstit_states(lay, dir_val, C, 0),
+                    &ws(lay + 1, dir_val, it_dest, C, 0));
+            if (p->cfg[states].dt == mkldnn_u8) {
+                // dequantize to f32
+                shift(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, C, 0),
+                        -1. * p->data_shift);
+                scale(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, C, 0),
+                        1. / p->data_scale);
+            }
+        }
+    }
+}
+
+/* lstm example:
+ * bwd: wsb keeps {dh, dc, dx} for every cell
+*/
+// Seeds the backward workspace `wsb_` with the output-gradient sequence
+// (diff_dst_layer as src_layer_) and the last-iteration state gradients
+// (diff_dst_iter as firstit_states_). Pure f32, no quantization here.
+void copy_init_bwd(alg_t alg, int sic, int slc, int dic, int dlc, int wc,
+        int batch, int n_layer, int n_iter, int n_dir, int n_states, float *ws_,
        const float *src_layer_, const float *firstit_states_,
        rnn_iter_direction_t iter_dir, rnn_layer_direction_t lay_dir,
-        int dir_val, int n_dir, bool is_bwd = false, bool is_concat = false) {
+        int dir_val, bool is_concat = false) {
    AOC<float> ws(
-            ws_, n_layer + 2, n_dir, n_iter + 2, n_states + is_bwd, batch, wc);
-    auto c_stride = is_bwd ? (is_concat ? 2 * dlc : dlc) : slc;
+            ws_, n_layer + 2, n_dir, n_iter + 2, n_states + 1, batch * wc);
+    auto c_stride = is_concat ? 2 * dlc : dlc;
    AOC<const float> src_layer(src_layer_, n_iter, batch * c_stride);
-    AOC<const float> firstit_states(firstit_states_, n_layer, n_dir, n_states,
-            batch, is_bwd ? dic : sic);
+    AOC<const float> firstit_states(
+            firstit_states_, n_layer, n_dir, n_states, batch * dic);
    int lay_dest = (lay_dir == bottom2top) ? 0 : n_layer + 1;
    int it_dest = (iter_dir == left2right) ? 0 : n_iter + 1;
-    if (!is_bwd) {
-        for (int it = 0; it < n_iter; it++)
-            copy(batch, slc, slc, wc, &src_layer(it, 0),
-                    &ws(lay_dest, dir_val, it + 1, H, 0, 0));
-
-        for (int lay = 0; lay < n_layer; lay++) {
-            copy(batch, sic, sic, wc, &firstit_states(lay, dir_val, H, 0, 0),
-                    &ws(lay + 1, dir_val, it_dest, H, 0, 0));
-            if (alg == VANILLA_LSTM) {
-                copy(batch, sic, sic, wc,
-                        &firstit_states(lay, dir_val, C, 0, 0),
-                        &ws(lay + 1, dir_val, it_dest, C, 0, 0));
+    // layer gradients go into the extra `n_states` slot of wsb
+    for (int it = 0; it < n_iter; it++)
+        copy(batch, dic, c_stride, wc,
+                &src_layer(it, dir_val * is_concat * dlc),
+                &ws(lay_dest, dir_val, it + 1, n_states, 0));
+
+    for (int lay = 0; lay < n_layer; lay++) {
+        copy(batch, dic, dic, wc, &firstit_states(lay, dir_val, H, 0),
+                &ws(lay + 1, dir_val, it_dest, H, 0));
+        if (alg == VANILLA_LSTM) {
+            copy(batch, dic, dic, wc, &firstit_states(lay, dir_val, C, 0),
+                    &ws(lay + 1, dir_val, it_dest, C, 0));
+        }
+    }
+}
+
+// Copies forward results out of the workspace into the user buffers:
+// last-layer H states to `lastlay_states_` and last-iteration H/C states to
+// `lastit_states_`, converting between the internal s8/f32 representation
+// and the u8/f32 destination data types via shift()/scale().
+void copy_res_fwd(const rnn_prb_t *p, alg_t alg, int sic, int slc, int dic,
+        int dlc, int wc, int batch, int n_layer, int n_iter, int n_dir,
+        int n_states, float *lastit_states_, float *lastlay_states_,
+        const float *ws_, rnn_iter_direction_t iter_dir,
+        rnn_layer_direction_t lay_dir, int dir_val, rnn_action_t action,
+        bool is_concat = false) {
+    int lastlay_c = is_concat ? 2 * dlc : dlc;
+    AOC<float> lastit_states(
+            lastit_states_, n_layer, n_dir, n_states, batch, dic);
+    AOC<float> lastlay_states(lastlay_states_, n_iter, batch, lastlay_c);
+    AOC<const float> ws(
+            ws_, n_layer + 2, n_dir, n_iter + 2, n_states, batch, wc);
+
+    // Copy states layer
+    for (int it = 0; it < n_iter; it++) {
+        for (int nb = 0; nb < batch; nb++) {
+            auto from = &ws(n_layer, dir_val, it + 1, H, nb, 0);
+            auto to = &lastlay_states(
+                    it, nb, action == action_concat ? dlc : 0);
+            copy(1, dlc, wc, lastlay_c, from, to, action);
+
+            if (p->cfg[dst_last_layer].dt == mkldnn_u8) {
+                // shift s8 internal ws to u8
+                shift(1, dlc, lastlay_c, to, p->data_shift);
+            } else {
+                // dequantize to f32
+                scale(1, dlc, lastlay_c, to, 1. / p->data_scale);
            }
        }
-    } else {
-        for (int it = 0; it < n_iter; it++)
-            copy(batch, dic, c_stride, wc,
-                    &src_layer(it, dir_val * is_concat * dlc),
-                    &ws(lay_dest, dir_val, it + 1, n_states, 0, 0));
-
-        for (int lay = 0; lay < n_layer; lay++) {
-            copy(batch, dic, dic, wc, &firstit_states(lay, dir_val, H, 0, 0),
-                    &ws(lay + 1, dir_val, it_dest, H, 0, 0));
-            if (alg == VANILLA_LSTM) {
-                copy(batch, dic, dic, wc,
-                        &firstit_states(lay, dir_val, C, 0, 0),
-                        &ws(lay + 1, dir_val, it_dest, C, 0, 0));
+    }
+
+    int it_source = (iter_dir == left2right) ? n_iter : 1;
+
+    // Copy states iteration
+    for (int lay = 0; lay < n_layer; lay++) {
+        if (alg == VANILLA_LSTM) {
+            // C state is f32 internally, so quantize on the way out
+            copy(batch, dic, wc, dic, &ws(lay + 1, dir_val, it_source, C, 0, 0),
+                    &lastit_states(lay, dir_val, C, 0, 0));
+            if (p->cfg[dst_last_iteration].dt == mkldnn_u8) {
+                // quantize internal f32 ws to u8
+                scale(batch, dic, dic, &lastit_states(lay, dir_val, C, 0, 0),
+                        p->data_scale);
+                shift(batch, dic, dic, &lastit_states(lay, dir_val, C, 0, 0),
+                        p->data_shift, true, p);
            }
        }
+        copy(batch, dic, wc, dic, &ws(lay + 1, dir_val, it_source, H, 0, 0),
+                &lastit_states(lay, dir_val, H, 0, 0));
+        if (p->cfg[dst_last_iteration].dt == mkldnn_u8) {
+            // shift s8 internal ws to u8
+            shift(batch, dic, dic, &lastit_states(lay, dir_val, H, 0, 0),
+                    p->data_shift);
+        } else {
+            // dequantize to f32
+            scale(batch, dic, dic, &lastit_states(lay, dir_val, H, 0, 0),
+                    1. / p->data_scale);
+        }
+    }
}
-void copy_res(alg_t alg, int sic, int slc, int dic, int dlc, int wc, int batch,
-        int n_layer, int n_iter, int n_states, float *lastit_states_,
-        float *lastlay_states_, const float *ws_,
-        mkldnn_rnn_direction_t direction, rnn_iter_direction_t iter_dir,
-        rnn_layer_direction_t lay_dir, int dir_val, int n_dir,
-        rnn_action_t action, bool is_bwd = false) {
-    int lastlay_c = is_bwd ?
-            slc :
-            (direction == mkldnn_bidirectional_concat) * dlc + dlc;
-    int lastiter_c = is_bwd ? sic : dic;
+// Copies backward results out of the workspace: input gradients (the extra
+// n_states slot of layer 1) to `lastlay_states_` and the H/C state
+// gradients to `lastit_states_`. Pure f32, no quantization here.
+void copy_res_bwd(alg_t alg, int sic, int slc, int dic, int dlc, int wc,
+        int batch, int n_layer, int n_iter, int n_dir, int n_states,
+        float *lastit_states_, float *lastlay_states_, const float *ws_,
+        rnn_iter_direction_t iter_dir, rnn_layer_direction_t lay_dir,
+        int dir_val, rnn_action_t action) {
    AOC<float> lastit_states(
-            lastit_states_, n_layer, n_dir, n_states, batch, lastiter_c);
-    AOC<float> lastlay_states(lastlay_states_, n_iter, batch, lastlay_c);
+            lastit_states_, n_layer, n_dir, n_states, batch, sic);
+    AOC<float> lastlay_states(lastlay_states_, n_iter, batch, slc);
    AOC<const float> ws(
-            ws_, n_layer + 2, n_dir, n_iter + 2, n_states + is_bwd, batch, wc);
+            ws_, n_layer + 2, n_dir, n_iter + 2, n_states + 1, batch, wc);
    for (int it = 0; it < n_iter; it++) {
        for (int nb = 0; nb < batch; nb++) {
            // copy H to last layer states
-            int lay = is_bwd ? 1 : n_layer;
-            int state = is_bwd ? n_states : H;
-            auto from = &ws(lay, dir_val, it + 1, state, nb, 0);
-            auto to = &lastlay_states(
-                    it, nb, (action == action_concat) && (!is_bwd) ? dlc : 0);
+            auto from = &ws(1, dir_val, it + 1, n_states, nb, 0);
+            auto to = &lastlay_states(it, nb, 0);
-            copy(1, is_bwd ? slc : dlc, wc, lastlay_c, from, to, action);
+            copy(1, slc, wc, slc, from, to, action);
        }
    }
+    // NOTE(review): `it_source` is used below but its declaration is not
+    // visible in this chunk — presumably computed from iter_dir as in
+    // copy_res_fwd; confirm against the full file.
    for (int lay = 0; lay < n_layer; lay++) {
        if (alg == VANILLA_LSTM) {
-            copy(batch, lastiter_c, wc, lastiter_c,
-                    &ws(lay + 1, dir_val, it_source, C, 0, 0),
+            copy(batch, sic, wc, sic, &ws(lay + 1, dir_val, it_source, C, 0, 0),
                    &lastit_states(lay, dir_val, C, 0, 0));
        }
-        copy(batch, lastiter_c, wc, lastiter_c,
-                &ws(lay + 1, dir_val, it_source, H, 0, 0),
+        copy(batch, sic, wc, sic, &ws(lay + 1, dir_val, it_source, H, 0, 0),
                &lastit_states(lay, dir_val, H, 0, 0));
    }
}
const int dlc = p->dlc;
const int wc = max(sic, max(slc, dic));
bool is_lbr = p->alg == LBR_GRU;
+ bool is_concat = direction == mkldnn_bidirectional_concat;
const int batch = p->mb;
const int n_gates = p->n_gates();
// we first need to copy the initial states and input into ws
// it simplifies the logic in the following code
print(80, "rnn_linear_fwd: call copy_init dir_val = %d\n", dir_val);
- copy_init(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter, n_states, ws_,
- src_layer_, src_iter_, iter_dir, lay_dir, dir_val, n_dir);
+ copy_init_fwd(p, alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter,
+ n_dir, n_states, ws_, src_layer_, src_iter_, iter_dir, lay_dir,
+ dir_val);
// We run the grid of computation
for (int il = 0; il < n_layer; il++) {
int iter = (iter_dir == left2right) ? it + 1 : n_iter - it;
int prev_iter = (iter_dir == left2right) ? iter - 1 : iter + 1;
int lay = il + 1;
- rnn_cell_fwd(alg, f, sic, slc, dic, wc, batch, n_gates,
+ rnn_cell_fwd(p, alg, f, sic, slc, dic, wc, batch, n_gates,
&ws(lay, dir_val, iter, H, 0, 0),
&ws(lay, dir_val, iter, C, 0, 0),
&gates(lay - 1, dir_val, iter - 1, 0, 0, 0),
&bias(lay - 1, dir_val, 0),
&ws(lay - 1, dir_val, iter, H, 0, 0),
&ws(lay, dir_val, prev_iter, H, 0, 0),
- &ws(lay, dir_val, prev_iter, C, 0, 0),
- ws_local_);
+ &ws(lay, dir_val, prev_iter, C, 0, 0), ws_local_);
}
}
// Finally we copy the results to the result buffers
- copy_res(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter, n_states,
- dst_iter_, dst_layer_, ws_, direction, iter_dir, lay_dir,
- dir_val, n_dir, action);
+ copy_res_fwd(p, alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter,
+ n_dir, n_states, dst_iter_, dst_layer_, ws_, iter_dir, lay_dir,
+ dir_val, action, is_concat);
};
switch (direction) {
float dh = diff_dst_layer(ib, ih) + diff_dst_iter_h(ib, ih);
float c = dst_iter_c(ib, ih);
float dho = tanhf(c) * dh;
- b_gates(ib, oho, ih) = dlogistic(ho) * dho;
+ b_gates(ib, oho, ih) = x_m_square(ho) * dho;
float dc_next = diff_dst_iter_c(ib, ih);
float dc = ho * dh * dtanhf(c) + dc_next;
float c_old = src_iter_c(ib, ih);
float dhf = c_old * dc;
- b_gates(ib, ohf, ih) = dlogistic(hf) * dhf;
+ b_gates(ib, ohf, ih) = x_m_square(hf) * dhf;
float dhi = hc * dc;
- b_gates(ib, ohi, ih) = dlogistic(hi) * dhi;
+ b_gates(ib, ohi, ih) = x_m_square(hi) * dhi;
float dhc = hi * dc;
- b_gates(ib, ohc, ih) = dtanhf(hc) * dhc;
+ b_gates(ib, ohc, ih) = one_m_square(hc) * dhc;
}
gemm("C", "T", "N", sic, n_gates * dic, batch, 1.0, src_iter_h_, wc, b_gates_,
AOC<float> dhr(dhr_, batch, wc);
AOC<float> hr(hr_, batch, wc);
-// dc = (1 - u) * dh; dc^ = dtanhf(c) * dc;
-// du = (h - u) * dh; du^ = dlogistic(u) * du;
+// dc = (1 - u) * dh; dc^ = one_m_square(c) * dc;
+// du = (h - c) * dh; du^ = x_m_square(u) * du;
// dhr = Wc dc^;
-// dr = h * dhr; dr^ = dlogistic(r) * dr;
+// dr = h * dhr; dr^ = x_m_square(r) * dr;
const int ohu = 0;
const int ohr = 1;
const int ohc = 2;
float dh = diff_dst_layer(ib, ih) + diff_dst_iter_h(ib, ih);
float du = (h - c) * dh;
float dc = (1.0f - u) * dh;
- b_gates(ib, ohu, ih) = dlogistic(u) * du;
- b_gates(ib, ohc, ih) = dtanhf(c) * dc;
+ b_gates(ib, ohu, ih) = x_m_square(u) * du;
+ b_gates(ib, ohc, ih) = one_m_square(c) * dc;
diff_src_iter(ib, ih) = dh * u;
}
- gemm("C", "N", "T", batch, slc, dic, 1.0, &(b_gates(0, 2, 0)), n_gates * dic,
- &(weights_layer(0, 2, 0)), n_gates * dic, 0.0, dhr_, wc);
+ gemm("C", "N", "T", batch, sic, dic, 1.0, &(b_gates(0, 2, 0)), n_gates * dic,
+ &(weights_iter_h(0, 2, 0)), n_gates * dic, 0.0, dhr_, wc);
for (int ib = 0; ib < batch; ib++)
for (int ih = 0; ih < dic; ih++) {
float dr = h * dhr(ib, ih);
hr(ib, ih) = h * r;
diff_src_iter(ib, ih) += dhr(ib, ih) * r;
- b_gates(ib, ohr, ih) = dlogistic(r) * dr;
+ b_gates(ib, ohr, ih) = x_m_square(r) * dr;
}
// dWx += xdu^ | xdr^ | xdc^
&weights_iter_h(0, 2, 0), n_gates * dic, 1.0, Wh_b_, dic);
-// dc = (1 - u) * dh; dc^ = dtanhf(c) * dc;
-// du = (h - u) * dh; du^ = dlogistic(u) * du;
-// dr = (Wh + b) * dc^; dr^ = dlogistic(r) * dr;
+// dc = (1 - u) * dh; dc^ = one_m_square(c) * dc;
+// du = (h - c) * dh; du^ = x_m_square(u) * du;
+// dr = (Wh + b) * dc^; dr^ = x_m_square(r) * dr;
const int ohu = 0;
const int ohr = 1;
const int ohc = 2;
float du = (h - c) * dh;
float dc = (1.0f - u) * dh;
- b_gates(ib, ohu, ih) = dlogistic(u) * du;
- b_gates(ib, ohc, ih) = dtanhf(c) * dc;
+ b_gates(ib, ohu, ih) = x_m_square(u) * du;
+ b_gates(ib, ohc, ih) = one_m_square(c) * dc;
float dr = Wh_b(ib, ih) * b_gates(ib, ohc, ih);
- b_gates(ib, ohr, ih) = dlogistic(r) * dr;
+ b_gates(ib, ohr, ih) = x_m_square(r) * dr;
b_gates_r(ib, ohu, ih) = b_gates(ib, ohu, ih);
b_gates_r(ib, ohr, ih) = b_gates(ib, ohr, ih);
rnn_layer_direction_t lay_dir, int dir_val, rnn_action_t action) {
// we first need to copy the initial states and input into ws
// it simplifies the logic in the following code
- copy_init(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter, n_states,
- wsb_, diff_dst_layer_, diff_dst_iter_, iter_dir, lay_dir,
- dir_val, n_dir, true, direction == mkldnn_bidirectional_concat);
+ copy_init_bwd(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter,
+ n_dir, n_states, wsb_, diff_dst_layer_, diff_dst_iter_,
+ iter_dir, lay_dir, dir_val,
+ direction == mkldnn_bidirectional_concat);
// We run the grid of computation
for (int j = n_layer - 1; j >= 0; j--) {
}
// Finally we copy the results to the result buffers
- copy_res(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter, n_states,
- diff_src_iter_, diff_src_layer_, wsb_, direction, iter_dir,
- lay_dir, dir_val, n_dir, action, true);
+ copy_res_bwd(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter, n_dir,
+ n_states, diff_src_iter_, diff_src_layer_, wsb_, iter_dir,
+ lay_dir, dir_val, action);
};
switch (direction) {