Tensor fgio_t = fgio.getBatchSlice(start_timestep, 1);
input_.dot(weight_xh, fgio_t);
- fgio_t.add_i(bias_h);
if (start_timestep > 0) {
Tensor hs_prev = hidden_.getBatchSlice(start_timestep - 1, 1);
hs_prev.dot(weight_hh, fgio_t, false, false, 1.0);
}
+ fgio_t.add_i(bias_h);
Tensor hif = fgio_t.getSharedDataTensor({batch, unit * 2}, 0, false);
Tensor hi = fgio_t.getSharedDataTensor({batch, unit}, 0, false);
Tensor hf = fgio_t.getSharedDataTensor({batch, unit}, unit, false);
djdw_x.setZero();
djdw_h.setZero();
djdb_h.setZero();
-
- dm_cell_.setZero();
- derivative_.setZero();
- d_fgio.setZero();
}
Tensor dh = derivative_.getBatchSlice(start_timestep, 1);
recurrent_acti_func.run_prime_fn(ho, dho, dho);
recurrent_acti_func.run_prime_fn(hif, dhif, dhif);
acti_func.run_prime_fn(hg, dhg, dhg);
- djdb_h.add_i(dfgio_t.sum(2));
- djdw_x.add_i(xs.dot(dfgio_t, true, false));
+ dfgio_t.sum(2, djdb_h, 1.0, 1.0);
+
+ xs.dot(dfgio_t, djdw_x, true, false, 1.0f);
if (start_timestep != 0) {
Tensor hs_prev = hidden_.getBatchSlice(start_timestep - 1, 1);
- djdw_h.add_i(hs_prev.dot(dfgio_t, true, false));
+ hs_prev.dot(dfgio_t, djdw_h, true, false, 1.0f);
Tensor dh_nx = derivative_.getBatchSlice(start_timestep - 1, 1);
dfgio_t.dot(weight_hh, dh_nx, false, true, 1.0f);
}