[LSTM] Fix single time step backwarding
author Jihoon Lee <jhoon.it.lee@samsung.com>
Tue, 19 Oct 2021 10:03:30 +0000 (19:03 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Wed, 20 Oct 2021 12:19:55 +0000 (21:19 +0900)
This patch updates single time step backwarding so that the values at
the last time step are properly set and read.
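
The key behavioral change is that the incoming derivative overwrites the
staged slot only when the current step is the last time step; earlier
single-step calls must accumulate instead, because later steps have
already staged their contribution there. A minimal sketch of that
pattern, with hypothetical names rather than nntrainer's Tensor API:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical helper illustrating the fill-vs-accumulate decision the
// patch introduces; `slot` stands in for the shared derivative view.
void stage_incoming_deriv(std::vector<float> &slot,
                          const std::vector<float> &incoming,
                          unsigned start_timestep, unsigned max_timestep) {
  if (start_timestep + 1 == max_timestep) {
    // last time step: the slot holds stale data, so overwrite (fill)
    std::copy(incoming.begin(), incoming.end(), slot.begin());
  } else {
    // earlier steps: keep gradients already staged by later steps (add)
    for (std::size_t i = 0; i < slot.size(); ++i)
      slot[i] += incoming[i];
  }
}
```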

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/layers/lstm.cpp

index b61090d62260f26732bf931fed9bb9a87af574c8..9248c3a500c97885f280e73c6213b40a5f9ca534 100644
@@ -371,7 +371,7 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
     end_timestep = cur_ts - 1;
   }
 
-  if (start_timestep == max_timestep - 1) {
+  if (start_timestep + 1 == max_timestep) {
     djdw_x.setZero();
     djdw_h.setZero();
     djdb_h.setZero();
@@ -389,10 +389,17 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
   } else {
     TensorDim d = derivative_.getDim();
     for (unsigned int b = 0; b < input_dim.batch(); ++b) {
-      float *data = derivative_.getAddress(b * d.width() * d.height() +
-                                           start_timestep * d.width());
-      float *rdata = incoming_deriv.getAddress(b * d.width());
-      std::copy(rdata, rdata + d.width(), data);
+      Tensor data = derivative_.getSharedDataTensor(
+        {d.width()}, b * d.width() * d.height() + start_timestep * d.width());
+
+      Tensor rdata =
+        incoming_deriv.getSharedDataTensor({d.width()}, b * d.width());
+      /// @note this is not copying from start ~ end but only start time step
+      if ((unsigned)start_timestep + 1 == max_timestep) {
+        data.fill(rdata);
+      } else {
+        data.add_i(rdata);
+      }
     }
   }
 
@@ -400,9 +407,6 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
     derivative_.multiply_i(context.getTensor(wt_idx[LSTMParams::dropout_mask]));
   }
 
-  Tensor dh_nx = Tensor(derivative_.width());
-  Tensor dc_nx = Tensor(derivative_.width());
-
   for (unsigned int b = 0; b < input_dim.batch(); ++b) {
     Tensor deriv_t = derivative_.getBatchSlice(b, 1);
     Tensor derivc_t = dm_cell_.getBatchSlice(b, 1);
@@ -410,9 +414,6 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
     Tensor hs_t = hidden_.getBatchSlice(b, 1);
     Tensor cs_t = m_cell_.getBatchSlice(b, 1);
 
-    dc_nx.setZero();
-    dh_nx.setZero();
-
     Tensor dh;
     Tensor xs;
     Tensor hs_prev;
@@ -423,11 +424,7 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
     Tensor fgio_ = fgio.getBatchSlice(b, 1);
 
     for (int t = start_timestep; t > end_timestep; t--) {
-      if (deriv_t.height() != 1)
-        dh =
-          deriv_t.getSharedDataTensor({deriv_t.width()}, t * deriv_t.width());
-      else
-        dh = deriv_t;
+      dh = deriv_t.getSharedDataTensor({deriv_t.width()}, t * deriv_t.width());
 
       dc =
         derivc_t.getSharedDataTensor({derivc_t.width()}, t * derivc_t.width());
@@ -456,10 +453,6 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
           cs_t.getSharedDataTensor({cs_t.width()}, (t - 1) * cs_t.width());
       }
 
-      if ((unsigned int)t < deriv_t.height() - 1) {
-        dh.add_i(dh_nx);
-      }
-
       Tensor dhi = dfgio_t.getSharedDataTensor({unit}, 0);
       Tensor dhf = dfgio_t.getSharedDataTensor({unit}, unit);
       Tensor dhg = dfgio_t.getSharedDataTensor({unit}, unit * 2);
@@ -473,14 +466,27 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
       acti_func.run_fn(cs, dho);
       dho.multiply_i(dh);
       acti_func.run_fn(cs, cs);
-      acti_func.run_prime_fn(cs, dc, dh);
-      dc.multiply_i(ho);
-      dc.add_i(dc_nx);
+
+      if ((unsigned)t + 1 == max_timestep) {
+        acti_func.run_prime_fn(cs, dc, dh);
+        dc.multiply_i(ho);
+      } else {
+        /// @todo optimize this by updating run_prime_fn to accumulate
+        Tensor dc_temp(dc.getDim());
+        acti_func.run_prime_fn(cs, dc_temp, dh);
+        dc_temp.multiply_i(ho);
+        dc.add_i(dc_temp);
+      }
+
+      if (t > 0) {
+        Tensor dc_nx = derivc_t.getSharedDataTensor({derivc_t.width()},
+                                                    (t - 1) * derivc_t.width());
+        dc.multiply(hf, dc_nx);
+      }
 
       dc.multiply(cs_prev, dhf);
       dc.multiply(hg, dhi);
       dc.multiply(hi, dhg);
-      dc.multiply(hf, dc_nx);
 
       recurrent_acti_func.run_prime_fn(ho, dho, dho);
       recurrent_acti_func.run_prime_fn(hf, dhf, dhf);
@@ -489,7 +495,11 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
       djdb_h.add_i(dfgio_t);
       djdw_x.add_i(xs.dot(dfgio_t, true, false));
       djdw_h.add_i(hs_prev.dot(dfgio_t, true, false));
-      dfgio_t.dot(weight_hh, dh_nx, false, true);
+      if (t > 0) {
+        Tensor dh_nx = deriv_t.getSharedDataTensor({deriv_t.width()},
+                                                   (t - 1) * deriv_t.width());
+        dfgio_t.dot(weight_hh, dh_nx, false, true, 1.0f);
+      }
     }
   }
 }
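
Taken together, the hunks replace the per-batch dh_nx / dc_nx
temporaries with getSharedDataTensor views into the (t - 1) slots of the
shared derivative tensors, accumulated in place (note the beta argument
1.0f on the final dot). The self-contained toy program below, with
made-up buffers and a made-up recurrent factor, mimics that flow: each
backward step writes its contribution for the previous step directly
into the shared buffer, which the next iteration then reads as its own
dh.

```cpp
#include <cstdio>
#include <vector>

int main() {
  const unsigned width = 2, timesteps = 3;
  // shared derivative buffer, one slot of `width` floats per time step
  std::vector<float> deriv(width * timesteps, 0.0f);
  // pretend the last time step received an incoming derivative of 1.0
  for (unsigned i = 0; i < width; ++i)
    deriv[(timesteps - 1) * width + i] = 1.0f;

  for (int t = timesteps - 1; t >= 0; --t) {
    float *dh = deriv.data() + t * width; // view of step t's slot
    // ... gate gradients would be computed from dh here ...
    if (t > 0) {
      // stage step t's contribution into step t-1's slot, analogous to
      // the accumulating dfgio_t.dot(weight_hh, dh_nx, false, true, 1.0f)
      for (unsigned i = 0; i < width; ++i)
        deriv[(t - 1) * width + i] += 0.5f * dh[i]; // toy recurrent term
    }
  }
  for (unsigned t = 0; t < timesteps; ++t)
    std::printf("step %u: dh[0] = %.2f\n", t, deriv[t * width]);
  return 0;
}
```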