end_timestep = cur_ts - 1;
}
- if (start_timestep == max_timestep - 1) {
+ if (start_timestep + 1 == max_timestep) {
djdw_x.setZero();
djdw_h.setZero();
djdb_h.setZero();
} else {
TensorDim d = derivative_.getDim();
for (unsigned int b = 0; b < input_dim.batch(); ++b) {
- float *data = derivative_.getAddress(b * d.width() * d.height() +
- start_timestep * d.width());
- float *rdata = incoming_deriv.getAddress(b * d.width());
- std::copy(rdata, rdata + d.width(), data);
+ Tensor data = derivative_.getSharedDataTensor(
+ {d.width()}, b * d.width() * d.height() + start_timestep * d.width());
+
+ Tensor rdata =
+ incoming_deriv.getSharedDataTensor({d.width()}, b * d.width());
+ /// @note this is not copying from start ~ end but only start time step
+ if ((unsigned)start_timestep + 1 == max_timestep) {
+ data.fill(rdata);
+ } else {
+ data.add_i(rdata);
+ }
}
}
derivative_.multiply_i(context.getTensor(wt_idx[LSTMParams::dropout_mask]));
}
- Tensor dh_nx = Tensor(derivative_.width());
- Tensor dc_nx = Tensor(derivative_.width());
-
for (unsigned int b = 0; b < input_dim.batch(); ++b) {
Tensor deriv_t = derivative_.getBatchSlice(b, 1);
Tensor derivc_t = dm_cell_.getBatchSlice(b, 1);
Tensor hs_t = hidden_.getBatchSlice(b, 1);
Tensor cs_t = m_cell_.getBatchSlice(b, 1);
- dc_nx.setZero();
- dh_nx.setZero();
-
Tensor dh;
Tensor xs;
Tensor hs_prev;
Tensor fgio_ = fgio.getBatchSlice(b, 1);
for (int t = start_timestep; t > end_timestep; t--) {
- if (deriv_t.height() != 1)
- dh =
- deriv_t.getSharedDataTensor({deriv_t.width()}, t * deriv_t.width());
- else
- dh = deriv_t;
+ dh = deriv_t.getSharedDataTensor({deriv_t.width()}, t * deriv_t.width());
dc =
derivc_t.getSharedDataTensor({derivc_t.width()}, t * derivc_t.width());
cs_t.getSharedDataTensor({cs_t.width()}, (t - 1) * cs_t.width());
}
- if ((unsigned int)t < deriv_t.height() - 1) {
- dh.add_i(dh_nx);
- }
-
Tensor dhi = dfgio_t.getSharedDataTensor({unit}, 0);
Tensor dhf = dfgio_t.getSharedDataTensor({unit}, unit);
Tensor dhg = dfgio_t.getSharedDataTensor({unit}, unit * 2);
acti_func.run_fn(cs, dho);
dho.multiply_i(dh);
acti_func.run_fn(cs, cs);
- acti_func.run_prime_fn(cs, dc, dh);
- dc.multiply_i(ho);
- dc.add_i(dc_nx);
+
+ if ((unsigned)t + 1 == max_timestep) {
+ acti_func.run_prime_fn(cs, dc, dh);
+ dc.multiply_i(ho);
+ } else {
+ /// @todo optimize this by updating run_prime_fn to accumulate
+ Tensor dc_temp(dc.getDim());
+ acti_func.run_prime_fn(cs, dc_temp, dh);
+ dc_temp.multiply_i(ho);
+ dc.add_i(dc_temp);
+ }
+
+ if (t > 0) {
+ Tensor dc_nx = derivc_t.getSharedDataTensor({derivc_t.width()},
+ (t - 1) * derivc_t.width());
+ dc.multiply(hf, dc_nx);
+ }
dc.multiply(cs_prev, dhf);
dc.multiply(hg, dhi);
dc.multiply(hi, dhg);
- dc.multiply(hf, dc_nx);
recurrent_acti_func.run_prime_fn(ho, dho, dho);
recurrent_acti_func.run_prime_fn(hf, dhf, dhf);
djdb_h.add_i(dfgio_t);
djdw_x.add_i(xs.dot(dfgio_t, true, false));
djdw_h.add_i(hs_prev.dot(dfgio_t, true, false));
- dfgio_t.dot(weight_hh, dh_nx, false, true);
+ if (t > 0) {
+ Tensor dh_nx = deriv_t.getSharedDataTensor({deriv_t.width()},
+ (t - 1) * deriv_t.width());
+ dfgio_t.dot(weight_hh, dh_nx, false, true, 1.0f);
+ }
}
}
}