[lstm] implement calcGradient for bidirectional lstm
author    hyeonseok lee <hs89.lee@samsung.com>
Mon, 24 Jan 2022 05:36:41 +0000 (14:36 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Mon, 9 May 2022 05:10:37 +0000 (14:10 +0900)
 - Implement calcGradient for bidirectional lstm
 - Add and enable test cases for bidirectional lstm

close #1726

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/layers/lstm.cpp
nntrainer/layers/lstmcell.cpp
nntrainer/layers/lstmcell_core.cpp
nntrainer/layers/lstmcell_core.h
nntrainer/layers/zoneout_lstmcell.cpp
packaging/unittest_models_v2.tar.gz
test/unittest/models/unittest_models_recurrent.cpp

index e696e75..f2a4092 100644 (file)
@@ -46,36 +46,40 @@ enum LSTMParams {
  * @brief run lstm forwarding for batch_first input
  *
  * @param NUM_GATE Number of gates, which is 4 for lstm
- * @param unit number of output neurons
  * @param batch_size batch size
- * @param max_timestep maximum timestep for lstm
+ * @param feature_size feature size
+ * @param disable_bias whether to disable bias or not
+ * @param unit number of output neurons
  * @param integrate_bias integrate bias_ih, bias_hh to bias_h
  * @param acti_func activation function for memory cell, cell state
  * @param recurrent_acti_func activation function for input/output/forget
  * gate
- * @param reverse indicate forward for reverse input in bidirectional lstm
  * @param enable_dropout whether to apply dropout
  * @param dropout_rate dropout rate
+ * @param max_timestep maximum timestep for lstm
+ * @param reverse indicate whether to process the input in the reverse
+ * (backward) direction in a bidirectional lstm
  * @param input_ input
- * @param weight_ih weight_ih. weight for input to hidden
- * @param weight_hh weight_hh. weight for hidden to hidden
- * @param bias_h bias_h. bias for input and hidden.
- * @param bias_ih bias_ih. bias for input
- * @param bias_hh bias_hh. bias for hidden
+ * @param weight_ih weight for input to hidden
+ * @param weight_hh weight for hidden to hidden
+ * @param bias_h bias for input and hidden.
+ * @param bias_ih bias for input
+ * @param bias_hh bias for hidden
  * @param hidden_state_ hidden state
  * @param cell_state_ cell state
  * @param ifgo_ input gate, forget gate, memory cell, output gate
  * @param mask_ dropout mask
  */
 static void batch_first_forwarding(
-  unsigned int NUM_GATE, const unsigned int unit, const unsigned int batch_size,
-  const unsigned int max_timestep, const unsigned int feature_size,
-  const bool disable_bias, const bool integrate_bias, ActiFunc &acti_func,
-  ActiFunc &recurrent_acti_func, const bool reverse, const bool enable_dropout,
-  const float dropout_rate, const Tensor &input_, const Tensor &weight_ih,
-  const Tensor &weight_hh, const Tensor &bias_h, const Tensor &bias_ih,
-  const Tensor &bias_hh, Tensor &hidden_state_, Tensor &cell_state_,
-  Tensor &ifgo_, const Tensor &mask_) {
+  unsigned int NUM_GATE, const unsigned int batch_size,
+  const unsigned int feature_size, const bool disable_bias,
+  const unsigned int unit, const bool integrate_bias, ActiFunc &acti_func,
+  ActiFunc &recurrent_acti_func, const bool enable_dropout,
+  const float dropout_rate, const unsigned int max_timestep, const bool reverse,
+  const Tensor &input_, const Tensor &weight_ih, const Tensor &weight_hh,
+  const Tensor &bias_h, const Tensor &bias_ih, const Tensor &bias_hh,
+  Tensor &hidden_state_, Tensor &cell_state_, Tensor &ifgo_,
+  const Tensor &mask_) {
   hidden_state_.setZero();
   cell_state_.setZero();
 
@@ -104,15 +108,16 @@ static void batch_first_forwarding(
         prev_cell_state = Tensor(unit);
         prev_cell_state.setZero();
       } else {
-        prev_cell_state =
-          cell_state_sample.getSharedDataTensor({unit}, (t - 1) * unit);
+        prev_cell_state = cell_state_sample.getSharedDataTensor(
+          {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
       }
-      Tensor cell_state =
-        cell_state_sample.getSharedDataTensor({unit}, t * unit);
-      Tensor ifgo =
-        ifgo_sample.getSharedDataTensor({unit * NUM_GATE}, unit * t * NUM_GATE);
+      Tensor cell_state = cell_state_sample.getSharedDataTensor(
+        {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+      Tensor ifgo = ifgo_sample.getSharedDataTensor(
+        {NUM_GATE * unit},
+        (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
 
-      lstmcell_forwarding(unit, 1, disable_bias, integrate_bias, acti_func,
+      lstmcell_forwarding(1, unit, disable_bias, integrate_bias, acti_func,
                           recurrent_acti_func, input, prev_hidden_state,
                           prev_cell_state, hidden_state, cell_state, weight_ih,
                           weight_hh, bias_h, bias_ih, bias_hh, ifgo);
@@ -127,6 +132,172 @@ static void batch_first_forwarding(
   }
 }
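
Note the mirrored indexing above: in the reverse pass, loop step t writes the
state slot max_timestep - 1 - t and, for t > 0, reads the slot written one step
earlier at max_timestep - t, while the forward pass uses t and t - 1. A minimal
standalone sketch of this mapping (hypothetical example code, not part of the
patch):

    // Prints which state slot each loop step writes and reads in the
    // forward and reverse passes; -1 marks the zero initial state.
    #include <cstdio>

    int main() {
      const unsigned int max_timestep = 4;
      for (unsigned int t = 0; t < max_timestep; ++t) {
        const unsigned int fwd_cur = t;
        const unsigned int rev_cur = max_timestep - 1 - t;
        std::printf("step %u: forward writes %u (reads %d), "
                    "reverse writes %u (reads %d)\n",
                    t, fwd_cur, t ? (int)(t - 1) : -1,
                    rev_cur, t ? (int)(max_timestep - t) : -1);
      }
      return 0;
    }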
 
+/**
+ * @brief calculate lstm gradient for batch_first input
+ *
+ * @param NUM_GATE Number of gates, which is 4 for lstm
+ * @param batch_size batch size
+ * @param feature_size feature size
+ * @param disable_bias whether to disable bias or not
+ * @param unit number of output neurons
+ * @param integrate_bias integrate bias_ih, bias_hh to bias_h
+ * @param acti_func activation function for memory cell, cell state
+ * @param recurrent_acti_func activation function for input/output/forget
+ * gate
+ * @param return_sequences whether to return the full output sequence
+ * @param bidirectional whether the lstm is bidirectional
+ * @param enable_dropout whether to apply dropout
+ * @param dropout_rate dropout rate
+ * @param max_timestep maximum timestep for lstm
+ * @param reverse indicate whether to process the input in the reverse
+ * (backward) direction in a bidirectional lstm
+ * @param input_ input
+ * @param incoming_derivative derivative of the output (the incoming derivative)
+ * @param d_weight_ih weight_ih(weight for input to hidden) gradient
+ * @param weight_hh weight for hidden to hidden
+ * @param d_weight_hh weight_hh(weight for hidden to hidden) gradient
+ * @param d_bias_h bias_h(bias for input and hidden) gradient
+ * @param d_bias_ih bias_ih(bias for input) gradient
+ * @param d_bias_hh bias_hh(bias for hidden) gradient
+ * @param hidden_state_ hidden state
+ * @param d_hidden_state_ hidden state gradient
+ * @param cell_state_ cell state
+ * @param d_cell_state_ cell state gradient
+ * @param ifgo_ input gate, forget gate, memory cell, output gate
+ * @param d_ifgo_ gradient for input gate, forget gate, memory cell, output gate
+ * @param mask_ dropout mask
+ */
+void batch_first_calcGradient(
+  unsigned int NUM_GATE, const unsigned int batch_size,
+  const unsigned int feature_size, const bool disable_bias,
+  const unsigned int unit, const bool integrate_bias, ActiFunc &acti_func,
+  ActiFunc &recurrent_acti_func, const bool return_sequences,
+  const bool bidirectional, const bool enable_dropout, const float dropout_rate,
+  const unsigned int max_timestep, const bool reverse, const Tensor &input_,
+  const Tensor &incoming_derivative, Tensor &d_weight_ih,
+  const Tensor &weight_hh, Tensor &d_weight_hh, Tensor &d_bias_h,
+  Tensor &d_bias_ih, Tensor &d_bias_hh, const Tensor &hidden_state_,
+  Tensor &d_hidden_state_, const Tensor &cell_state_, Tensor &d_cell_state_,
+  const Tensor &ifgo_, Tensor &d_ifgo_, const Tensor &mask_) {
+  const unsigned int bidirectional_constant = bidirectional ? 2 : 1;
+
+  d_weight_ih.setZero();
+  d_weight_hh.setZero();
+  if (!disable_bias) {
+    if (integrate_bias) {
+      d_bias_h.setZero();
+    } else {
+      d_bias_ih.setZero();
+      d_bias_hh.setZero();
+    }
+  }
+
+  d_cell_state_.setZero();
+  d_hidden_state_.setZero();
+
+  if (return_sequences && !bidirectional && !reverse) {
+    std::copy(incoming_derivative.getData(),
+              incoming_derivative.getData() + incoming_derivative.size(),
+              d_hidden_state_.getData());
+  } else {
+    unsigned int end_timestep = return_sequences ? max_timestep : 1;
+    for (unsigned int batch = 0; batch < batch_size; ++batch) {
+      for (unsigned int timestep = 0; timestep < end_timestep; ++timestep) {
+        Tensor d_hidden_state_sample = d_hidden_state_.getSharedDataTensor(
+          {unit}, batch * max_timestep * unit +
+                    (return_sequences ? 0 : max_timestep - 1) * unit +
+                    timestep * unit);
+        Tensor incoming_derivative_sample =
+          incoming_derivative.getSharedDataTensor(
+            {unit}, batch * (return_sequences ? max_timestep : 1) *
+                        bidirectional_constant * unit +
+                      timestep * bidirectional_constant * unit +
+                      (reverse ? unit : 0));
+        d_hidden_state_sample.add_i(incoming_derivative_sample);
+      }
+    }
+  }
+
+  if (enable_dropout) {
+    d_hidden_state_.multiply_i(mask_);
+  }
+
+  for (unsigned int batch = 0; batch < batch_size; ++batch) {
+    const Tensor input_sample = input_.getBatchSlice(batch, 1);
+
+    const Tensor hidden_state_sample = hidden_state_.getBatchSlice(batch, 1);
+    Tensor d_hidden_state_sample = d_hidden_state_.getBatchSlice(batch, 1);
+    const Tensor cell_state_sample = cell_state_.getBatchSlice(batch, 1);
+    Tensor d_cell_state_sample = d_cell_state_.getBatchSlice(batch, 1);
+
+    const Tensor ifgo_sample = ifgo_.getBatchSlice(batch, 1);
+    Tensor d_ifgo_sample = d_ifgo_.getBatchSlice(batch, 1);
+
+    Tensor input;
+    Tensor prev_hidden_state;
+    Tensor d_prev_hidden_state;
+    Tensor prev_cell_state;
+    Tensor d_prev_cell_state;
+    Tensor d_hidden_state;
+    Tensor cell_state;
+    Tensor d_cell_state;
+
+    for (int t = max_timestep - 1; t > -1; t--) {
+      input = input_sample.getSharedDataTensor(
+        {feature_size}, (reverse ? max_timestep - 1 - t : t) * feature_size);
+
+      if (!t) {
+        prev_hidden_state = Tensor(unit);
+        prev_hidden_state.setZero();
+        d_prev_hidden_state = Tensor(unit);
+        d_prev_hidden_state.setZero();
+      } else {
+        prev_hidden_state = hidden_state_sample.getSharedDataTensor(
+          {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+        d_prev_hidden_state = d_hidden_state_sample.getSharedDataTensor(
+          {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+      }
+      d_hidden_state = d_hidden_state_sample.getSharedDataTensor(
+        {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+
+      if (!t) {
+        prev_cell_state = Tensor(unit);
+        prev_cell_state.setZero();
+        d_prev_cell_state = Tensor(unit);
+        d_prev_cell_state.setZero();
+      } else {
+        prev_cell_state = cell_state_sample.getSharedDataTensor(
+          {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+        d_prev_cell_state = d_cell_state_sample.getSharedDataTensor(
+          {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+      }
+      cell_state = cell_state_sample.getSharedDataTensor(
+        {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+      d_cell_state = d_cell_state_sample.getSharedDataTensor(
+        {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+
+      Tensor ifgo = ifgo_sample.getSharedDataTensor(
+        {NUM_GATE * unit},
+        (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
+      Tensor d_ifgo = d_ifgo_sample.getSharedDataTensor(
+        {NUM_GATE * unit},
+        (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
+
+      // Temporary variable for d_prev_hidden_state. d_prev_hidden_state already
+      // holds precalculated values from the incoming derivatives
+      Tensor d_prev_hidden_state_temp;
+
+      lstmcell_calcGradient(1, unit, disable_bias, integrate_bias, acti_func,
+                            recurrent_acti_func, input, prev_hidden_state,
+                            d_prev_hidden_state_temp, prev_cell_state,
+                            d_prev_cell_state, d_hidden_state, cell_state,
+                            d_cell_state, d_weight_ih, weight_hh, d_weight_hh,
+                            d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo);
+      d_prev_hidden_state.add_i(d_prev_hidden_state_temp);
+    }
+  }
+}
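
In the accumulation branch above, the incoming derivative is read at a strided
offset because a bidirectional layer concatenates the forward and reverse
outputs along the feature axis. A small sketch of that offset arithmetic
(hypothetical helper name, assuming the [batch][timestep][direction * unit]
layout implied by the getSharedDataTensor() call):

    // Offset of one direction's {unit}-sized slice inside the incoming
    // derivative, mirroring the getSharedDataTensor() call above.
    unsigned int derivative_offset(unsigned int batch, unsigned int timestep,
                                   unsigned int unit, unsigned int max_timestep,
                                   bool return_sequences, bool bidirectional,
                                   bool reverse) {
      const unsigned int dirs = bidirectional ? 2 : 1;
      const unsigned int steps = return_sequences ? max_timestep : 1;
      return batch * steps * dirs * unit // skip earlier batches
             + timestep * dirs * unit    // skip earlier timesteps
             + (reverse ? unit : 0);     // reverse half follows the forward half
    }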
+
 LSTMLayer::LSTMLayer() :
   LayerImpl(),
   lstm_props(props::Unit(), props::IntegrateBias(),
@@ -345,7 +516,7 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
   const unsigned int max_timestep =
     std::get<props::MaxTimestep>(lstm_props).get();
 
-  unsigned int bidirectional_constant = bidirectional ? 2 : 1;
+  const unsigned int bidirectional_constant = bidirectional ? 2 : 1;
   bool enable_dropout = dropout_rate > epsilon && training;
 
   const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
@@ -375,11 +546,11 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
                    ? context.getTensor(wt_idx[LSTMParams::dropout_mask])
                    : empty;
 
-  batch_first_forwarding(
-    NUM_GATE, unit, batch_size, max_timestep, feature_size, disable_bias,
-    integrate_bias, acti_func, recurrent_acti_func, false, enable_dropout,
-    dropout_rate, input, weight_ih, weight_hh, bias_h, bias_ih, bias_hh,
-    hidden_state, cell_state, ifgo, mask);
+  batch_first_forwarding(NUM_GATE, batch_size, feature_size, disable_bias, unit,
+                         integrate_bias, acti_func, recurrent_acti_func,
+                         enable_dropout, dropout_rate, max_timestep, false,
+                         input, weight_ih, weight_hh, bias_h, bias_ih, bias_hh,
+                         hidden_state, cell_state, ifgo, mask);
 
   if (bidirectional) {
     const Tensor &reverse_weight_ih =
@@ -406,10 +577,10 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
     Tensor &reverse_ifgo = context.getTensor(wt_idx[LSTMParams::reverse_ifgo]);
 
     batch_first_forwarding(
-      NUM_GATE, unit, batch_size, max_timestep, feature_size, disable_bias,
-      integrate_bias, acti_func, recurrent_acti_func, true, enable_dropout,
-      dropout_rate, input, reverse_weight_ih, reverse_weight_hh, reverse_bias_h,
-      reverse_bias_ih, reverse_bias_hh, reverse_hidden_state,
+      NUM_GATE, batch_size, feature_size, disable_bias, unit, integrate_bias,
+      acti_func, recurrent_acti_func, enable_dropout, dropout_rate,
+      max_timestep, true, input, reverse_weight_ih, reverse_weight_hh,
+      reverse_bias_h, reverse_bias_ih, reverse_bias_hh, reverse_hidden_state,
       reverse_cell_state, reverse_ifgo, mask);
   }
 
@@ -417,11 +588,9 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
     std::copy(hidden_state.getData(),
               hidden_state.getData() + hidden_state.size(), output.getData());
   } else {
-    unsigned int start_timestep = 0;
     unsigned int end_timestep = return_sequences ? max_timestep : 1;
     for (unsigned int batch = 0; batch < batch_size; ++batch) {
-      for (unsigned int timestep = start_timestep; timestep < end_timestep;
-           ++timestep) {
+      for (unsigned int timestep = 0; timestep < end_timestep; ++timestep) {
         float *hidden_state_data = hidden_state.getAddress(
           batch * max_timestep * unit +
           (return_sequences ? 0 : (max_timestep - 1) * unit) + timestep * unit);
@@ -447,11 +616,23 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
 }
 
 void LSTMLayer::calcDerivative(RunLayerContext &context) {
+  const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
+
   Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
   const Tensor &weight_ih = context.getWeight(wt_idx[LSTMParams::weight_ih]);
   const Tensor &d_ifgos = context.getTensorGrad(wt_idx[LSTMParams::ifgo]);
 
-  lstmcell_calcDerivative(d_ifgos, weight_ih, outgoing_derivative);
+  lstmcell_calcDerivative(outgoing_derivative, weight_ih, d_ifgos);
+
+  if (bidirectional) {
+    const Tensor &reverse_weight_ih =
+      context.getWeight(wt_idx[LSTMParams::reverse_weight_ih]);
+    const Tensor &reverse_d_ifgos =
+      context.getTensorGrad(wt_idx[LSTMParams::reverse_ifgo]);
+
+    lstmcell_calcDerivative(outgoing_derivative, reverse_weight_ih,
+                            reverse_d_ifgos, 1.0f);
+  }
 }
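
Both calls write into the same outgoing_derivative: the first uses the default
alpha of 0.0f and overwrites, while the reverse call passes 1.0f so its
contribution accumulates on top of the forward one. Schematically (a sketch of
the intended semantics, mirroring the calls above):

    // Effective computation in LSTMLayer::calcDerivative:
    //   outgoing_derivative  = d_ifgos * weight_ih^T                  (alpha = 0.0f)
    //   outgoing_derivative += reverse_d_ifgos * reverse_weight_ih^T  (alpha = 1.0f)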
 
 void LSTMLayer::calcGradient(RunLayerContext &context) {
@@ -462,17 +643,17 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
   const bool integrate_bias = std::get<props::IntegrateBias>(lstm_props).get();
   const bool return_sequences =
     std::get<props::ReturnSequences>(lstm_props).get();
+  const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
   const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();
   const unsigned int max_timestep =
     std::get<props::MaxTimestep>(lstm_props).get();
 
-  unsigned int start_timestep = max_timestep - 1;
-  int end_timestep = -1;
+  bool enable_dropout = dropout_rate > epsilon;
 
-  const Tensor &inputs = context.getInput(SINGLE_INOUT_IDX);
+  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
   const Tensor &incoming_derivative =
     context.getIncomingDerivative(SINGLE_INOUT_IDX);
-  const TensorDim input_dim = inputs.getDim();
+  const TensorDim input_dim = input.getDim();
   const unsigned int batch_size = input_dim.batch();
   const unsigned int feature_size = input_dim.width();
 
@@ -490,131 +671,77 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
                         ? context.getWeightGrad(wt_idx[LSTMParams::bias_hh])
                         : empty;
 
-  Tensor &hs = context.getTensor(wt_idx[LSTMParams::hidden_state]);
-  Tensor &d_hs = context.getTensorGrad(wt_idx[LSTMParams::hidden_state]);
-  Tensor &cs = context.getTensor(wt_idx[LSTMParams::cell_state]);
-  Tensor &d_cs = context.getTensorGrad(wt_idx[LSTMParams::cell_state]);
+  const Tensor &hidden_state =
+    context.getTensor(wt_idx[LSTMParams::hidden_state]);
+  Tensor &d_hidden_state =
+    context.getTensorGrad(wt_idx[LSTMParams::hidden_state]);
+  const Tensor &cell_state = context.getTensor(wt_idx[LSTMParams::cell_state]);
+  Tensor &d_cell_state = context.getTensorGrad(wt_idx[LSTMParams::cell_state]);
 
-  Tensor &ifgos = context.getTensor(wt_idx[LSTMParams::ifgo]);
-  Tensor &d_ifgos = context.getTensorGrad(wt_idx[LSTMParams::ifgo]);
+  const Tensor &ifgo = context.getTensor(wt_idx[LSTMParams::ifgo]);
+  Tensor &d_ifgo = context.getTensorGrad(wt_idx[LSTMParams::ifgo]);
 
-  d_weight_ih.setZero();
-  d_weight_hh.setZero();
-  if (!disable_bias) {
-    if (integrate_bias) {
-      d_bias_h.setZero();
-    } else {
-      d_bias_ih.setZero();
-      d_bias_hh.setZero();
-    }
-  }
+  const Tensor &mask = enable_dropout
+                         ? context.getTensor(wt_idx[LSTMParams::dropout_mask])
+                         : empty;
 
-  d_cs.setZero();
-  d_hs.setZero();
+  batch_first_calcGradient(
+    NUM_GATE, batch_size, feature_size, disable_bias, unit, integrate_bias,
+    acti_func, recurrent_acti_func, return_sequences, bidirectional,
+    enable_dropout, dropout_rate, max_timestep, false, input,
+    incoming_derivative, d_weight_ih, weight_hh, d_weight_hh, d_bias_h,
+    d_bias_ih, d_bias_hh, hidden_state, d_hidden_state, cell_state,
+    d_cell_state, ifgo, d_ifgo, mask);
 
-  if (return_sequences) {
-    std::copy(incoming_derivative.getData(),
-              incoming_derivative.getData() + incoming_derivative.size(),
-              d_hs.getData());
-  } else {
-    for (unsigned int batch = 0; batch < batch_size; ++batch) {
-      Tensor data = d_hs.getSharedDataTensor(
-        {unit}, batch * max_timestep * unit + start_timestep * unit);
-
-      Tensor rdata =
-        incoming_derivative.getSharedDataTensor({unit}, batch * unit);
-      /// @note this is not copying from start ~ end but only start time
-      /// step
-      // This is copying for self rolling as well as last recurrent
-      // unrolled.
-      if ((unsigned)start_timestep + 1 == max_timestep) {
-        data.fill(rdata);
-      } else {
-        data.add_i(rdata);
-      }
-    }
-  }
-
-  if (dropout_rate > epsilon) {
-    d_hs.multiply_i(context.getTensor(wt_idx[LSTMParams::dropout_mask]));
-  }
-
-  for (unsigned int batch = 0; batch < batch_size; ++batch) {
-    const Tensor input_batch = inputs.getBatchSlice(batch, 1);
-
-    Tensor hs_batch = hs.getBatchSlice(batch, 1);
-    Tensor d_hidden_state_batch = d_hs.getBatchSlice(batch, 1);
-    Tensor cs_batch = cs.getBatchSlice(batch, 1);
-    Tensor d_cell_state_batch = d_cs.getBatchSlice(batch, 1);
-
-    Tensor ifgo_batch = ifgos.getBatchSlice(batch, 1);
-    Tensor d_ifgo_batch = d_ifgos.getBatchSlice(batch, 1);
-
-    Tensor input;
-    Tensor prev_hidden_state;
-    Tensor d_prev_hidden_state;
-    Tensor prev_cell_state;
-    Tensor d_prev_cell_state;
-    Tensor d_hidden_state;
-    Tensor cell_state;
-    Tensor d_cell_state;
-
-    for (int t = start_timestep; t > end_timestep; t--) {
-      if (input_batch.height() != 1)
-        input =
-          input_batch.getSharedDataTensor({feature_size}, t * feature_size);
-      else
-        input = input_batch;
-
-      if (!t) {
-        prev_hidden_state = Tensor(unit);
-        prev_hidden_state.setZero();
-        d_prev_hidden_state = Tensor(unit);
-        d_prev_hidden_state.setZero();
-      } else {
-        prev_hidden_state =
-          hs_batch.getSharedDataTensor({unit}, (t - 1) * unit);
-        d_prev_hidden_state =
-          d_hidden_state_batch.getSharedDataTensor({unit}, (t - 1) * unit);
-      }
-      d_hidden_state =
-        d_hidden_state_batch.getSharedDataTensor({unit}, t * unit);
-
-      if (!t) {
-        prev_cell_state = Tensor(unit);
-        prev_cell_state.setZero();
-        d_prev_cell_state = Tensor(unit);
-        d_prev_cell_state.setZero();
-      } else {
-        prev_cell_state = cs_batch.getSharedDataTensor({unit}, (t - 1) * unit);
-        d_prev_cell_state =
-          d_cell_state_batch.getSharedDataTensor({unit}, (t - 1) * unit);
-      }
-      cell_state = cs_batch.getSharedDataTensor({unit}, t * unit);
-      d_cell_state = d_cell_state_batch.getSharedDataTensor({unit}, t * unit);
-
-      Tensor ifgo =
-        ifgo_batch.getSharedDataTensor({unit * NUM_GATE}, unit * t * NUM_GATE);
-      Tensor d_ifgo = d_ifgo_batch.getSharedDataTensor({unit * NUM_GATE},
-                                                       unit * t * NUM_GATE);
-
-      // Temporary variable for d_prev_hidden_state. d_prev_hidden_state
-      // already have precalculated values from incomming derivatives
-      Tensor d_prev_hidden_state_temp;
+  if (bidirectional) {
+    Tensor &reverse_d_weight_ih =
+      context.getWeightGrad(wt_idx[LSTMParams::reverse_weight_ih]);
+    const Tensor &reverse_weight_hh =
+      context.getWeight(wt_idx[LSTMParams::reverse_weight_hh]);
+    Tensor &reverse_d_weight_hh =
+      context.getWeightGrad(wt_idx[LSTMParams::reverse_weight_hh]);
+    Tensor &reverse_d_bias_h =
+      !disable_bias && integrate_bias
+        ? context.getWeightGrad(wt_idx[LSTMParams::reverse_bias_h])
+        : empty;
+    Tensor &reverse_d_bias_ih =
+      !disable_bias && !integrate_bias
+        ? context.getWeightGrad(wt_idx[LSTMParams::reverse_bias_ih])
+        : empty;
+    Tensor &reverse_d_bias_hh =
+      !disable_bias && !integrate_bias
+        ? context.getWeightGrad(wt_idx[LSTMParams::reverse_bias_hh])
+        : empty;
 
-      lstmcell_calcGradient(unit, 1, disable_bias, integrate_bias, acti_func,
-                            recurrent_acti_func, input, prev_hidden_state,
-                            d_prev_hidden_state_temp, prev_cell_state,
-                            d_prev_cell_state, d_hidden_state, cell_state,
-                            d_cell_state, d_weight_ih, weight_hh, d_weight_hh,
-                            d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo);
-      d_prev_hidden_state.add_i(d_prev_hidden_state_temp);
-    }
+    const Tensor &reverse_hidden_state =
+      context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]);
+    Tensor &reverse_d_hidden_state =
+      context.getTensorGrad(wt_idx[LSTMParams::reverse_hidden_state]);
+    const Tensor &reverse_cell_state =
+      context.getTensor(wt_idx[LSTMParams::reverse_cell_state]);
+    Tensor &reverse_d_cell_state =
+      context.getTensorGrad(wt_idx[LSTMParams::reverse_cell_state]);
+
+    const Tensor &reverse_ifgo =
+      context.getTensor(wt_idx[LSTMParams::reverse_ifgo]);
+    Tensor &reverse_d_ifgo =
+      context.getTensorGrad(wt_idx[LSTMParams::reverse_ifgo]);
+
+    batch_first_calcGradient(
+      NUM_GATE, batch_size, feature_size, disable_bias, unit, integrate_bias,
+      acti_func, recurrent_acti_func, return_sequences, bidirectional,
+      enable_dropout, dropout_rate, max_timestep, true, input,
+      incoming_derivative, reverse_d_weight_ih, reverse_weight_hh,
+      reverse_d_weight_hh, reverse_d_bias_h, reverse_d_bias_ih,
+      reverse_d_bias_hh, reverse_hidden_state, reverse_d_hidden_state,
+      reverse_cell_state, reverse_d_cell_state, reverse_ifgo, reverse_d_ifgo,
+      mask);
   }
 }
 
 void LSTMLayer::setBatch(RunLayerContext &context, unsigned int batch) {
   const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
+  const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();
 
   context.updateTensor(wt_idx[LSTMParams::hidden_state], batch);
   context.updateTensor(wt_idx[LSTMParams::cell_state], batch);
@@ -626,7 +753,7 @@ void LSTMLayer::setBatch(RunLayerContext &context, unsigned int batch) {
     context.updateTensor(wt_idx[LSTMParams::reverse_ifgo], batch);
   }
 
-  if (std::get<props::DropOutRate>(lstm_props).get() > epsilon) {
+  if (dropout_rate > epsilon) {
     context.updateTensor(wt_idx[LSTMParams::dropout_mask], batch);
   }
 }
index 6e84220..4c9d4e3 100644 (file)
@@ -219,7 +219,7 @@ void LSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
 
   Tensor &ifgo = context.getTensor(wt_idx[LSTMCellParams::ifgo]);
 
-  lstmcell_forwarding(unit, batch_size, disable_bias, integrate_bias, acti_func,
+  lstmcell_forwarding(batch_size, unit, disable_bias, integrate_bias, acti_func,
                       recurrent_acti_func, input, prev_hidden_state,
                       prev_cell_state, hidden_state, cell_state, weight_ih,
                       weight_hh, bias_h, bias_ih, bias_hh, ifgo);
@@ -239,7 +239,7 @@ void LSTMCellLayer::calcDerivative(RunLayerContext &context) {
   Tensor &outgoing_derivative =
     context.getOutgoingDerivative(INOUT_INDEX::INPUT);
 
-  lstmcell_calcDerivative(d_ifgo, weight_ih, outgoing_derivative);
+  lstmcell_calcDerivative(outgoing_derivative, weight_ih, d_ifgo);
 }
 
 void LSTMCellLayer::calcGradient(RunLayerContext &context) {
@@ -317,7 +317,7 @@ void LSTMCellLayer::calcGradient(RunLayerContext &context) {
   }
 
   lstmcell_calcGradient(
-    unit, batch_size, disable_bias, integrate_bias, acti_func,
+    batch_size, unit, disable_bias, integrate_bias, acti_func,
     recurrent_acti_func, input, prev_hidden_state, d_prev_hidden_state,
     prev_cell_state, d_prev_cell_state,
     dropout_rate > epsilon ? d_hidden_state_masked : d_hidden_state, cell_state,
index 5cf5b73..827e581 100644 (file)
@@ -17,7 +17,7 @@
 
 namespace nntrainer {
 
-void lstmcell_forwarding(const unsigned int unit, const unsigned int batch_size,
+void lstmcell_forwarding(const unsigned int batch_size, const unsigned int unit,
                          const bool disable_bias, const bool integrate_bias,
                          ActiFunc &acti_func, ActiFunc &recurrent_acti_func,
                          const Tensor &input, const Tensor &prev_hidden_state,
@@ -59,13 +59,14 @@ void lstmcell_forwarding(const unsigned int unit, const unsigned int batch_size,
   hidden_state.multiply_i_strided(output_gate);
 }
 
-void lstmcell_calcDerivative(const Tensor &d_ifgo, const Tensor &weight_ih,
-                             Tensor &outgoing_derivative) {
-  d_ifgo.dot(weight_ih, outgoing_derivative, false, true);
+void lstmcell_calcDerivative(Tensor &outgoing_derivative,
+                             const Tensor &weight_ih, const Tensor &d_ifgo,
+                             const float alpha) {
+  d_ifgo.dot(weight_ih, outgoing_derivative, false, true, alpha);
 }
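
Judging from the call sites, alpha scales the existing contents of
outgoing_derivative before the product d_ifgo * weight_ih^T is added, so 0.0f
overwrites and 1.0f accumulates. A naive scalar illustration of that contract
(hypothetical standalone code, assuming row-major d_ifgo of shape
[rows x inner] and weight_ih of shape [cols x inner]):

    // out = d_ifgo * weight_ih^T + alpha * out
    void naive_dot_transB(const float *d_ifgo, const float *weight_ih,
                          float *out, unsigned int rows, unsigned int cols,
                          unsigned int inner, float alpha) {
      for (unsigned int r = 0; r < rows; ++r) {
        for (unsigned int c = 0; c < cols; ++c) {
          float sum = 0.0f;
          for (unsigned int k = 0; k < inner; ++k)
            sum += d_ifgo[r * inner + k] * weight_ih[c * inner + k];
          out[r * cols + c] = sum + alpha * out[r * cols + c];
        }
      }
    }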
 
 void lstmcell_calcGradient(
-  const unsigned int unit, const unsigned int batch_size,
+  const unsigned int batch_size, const unsigned int unit,
   const bool disable_bias, const bool integrate_bias, ActiFunc &acti_func,
   ActiFunc &recurrent_acti_func, const Tensor &input,
   const Tensor &prev_hidden_state, Tensor &d_prev_hidden_state,
index b39d2e8..3a3b7d4 100644 (file)
@@ -22,8 +22,26 @@ namespace nntrainer {
 /**
  * @brief lstm cell forwarding implementation
  *
+ * @param batch_size batch size
+ * @param unit number of output neurons
+ * @param disable_bias whether to disable bias or not
+ * @param integrate_bias integrate bias_ih, bias_hh to bias_h
+ * @param acti_func activation function for memory cell, cell state
+ * @param recurrent_acti_func activation function for input/output/forget
+ * gate
+ * @param input input
+ * @param prev_hidden_state previous hidden state
+ * @param prev_cell_state previous cell state
+ * @param hidden_state hidden state
+ * @param cell_state cell state
+ * @param weight_ih weight for input to hidden
+ * @param weight_hh weight for hidden to hidden
+ * @param bias_h bias for input and hidden.
+ * @param bias_ih bias for input
+ * @param bias_hh bias for hidden
+ * @param ifgo input gate, forget gate, memory cell, output gate
  */
-void lstmcell_forwarding(const unsigned int unit, const unsigned int batch_size,
+void lstmcell_forwarding(const unsigned int batch_size, const unsigned int unit,
                          const bool disable_bias, const bool integrate_bias,
                          ActiFunc &acti_func, ActiFunc &recurrent_acti_func,
                          const Tensor &input, const Tensor &prev_hidden_state,
@@ -36,16 +54,44 @@ void lstmcell_forwarding(const unsigned int unit, const unsigned int batch_size,
 /**
  * @brief lstm cell calculate derivative implementation
  *
+ * @param outgoing_derivative derivative for input
+ * @param weight_ih weight for input to hidden
+ * @param d_ifgo gradient for input gate, forget gate, memory cell, output gate
+ * @param alpha scale for the existing value of outgoing_derivative
  */
-void lstmcell_calcDerivative(const Tensor &d_ifgo, const Tensor &weight_ih,
-                             Tensor &outgoing_derivative);
+void lstmcell_calcDerivative(Tensor &outgoing_derivative,
+                             const Tensor &weight_ih, const Tensor &d_ifgo,
+                             const float alpha = 0.0f);
 
 /**
  * @brief lstm cell calculate gradient implementation
  *
+ * @param batch_size batch size
+ * @param unit number of output neurons
+ * @param disable_bias whether to disable bias or not
+ * @param integrate_bias integrate bias_ih, bias_hh to bias_h
+ * @param acti_func activation function for memory cell, cell state
+ * @param recurrent_acti_func activation function for input/output/forget
+ * gate
+ * @param input input
+ * @param prev_hidden_state previous hidden state
+ * @param d_prev_hidden_state previous hidden state gradient
+ * @param prev_cell_state previous cell state
+ * @param d_prev_cell_state previous cell state gradient
+ * @param d_hidden_state hidden state gradient
+ * @param cell_state cell state
+ * @param d_cell_state cell state gradient
+ * @param d_weight_ih weight_ih(weight for input to hidden) gradient
+ * @param weight_hh weight for hidden to hidden
+ * @param d_weight_hh weight_hh(weight for hidden to hidden) gradient
+ * @param d_bias_h bias_h(bias for input and hidden) gradient
+ * @param d_bias_ih bias_ih(bias for input) gradient
+ * @param d_bias_hh bias_hh(bias for hidden) gradient
+ * @param ifgo input gate, forget gate, memory cell, output gate
+ * @param d_ifgo gradient for input gate, forget gate, memory cell, output gate
  */
 void lstmcell_calcGradient(
-  const unsigned int unit, const unsigned int batch_size,
+  const unsigned int batch_size, const unsigned int unit,
   const bool disable_bias, const bool integrate_bias, ActiFunc &acti_func,
   ActiFunc &recurrent_acti_func, const Tensor &input,
   const Tensor &prev_hidden_state, Tensor &d_prev_hidden_state,
index 1392d5a..46f1b6d 100644 (file)
@@ -323,7 +323,7 @@ void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
       : empty;
 
   // Todo: pass lstm_cell_state as an argument at inference
-  lstmcell_forwarding(unit, batch_size, disable_bias, integrate_bias, acti_func,
+  lstmcell_forwarding(batch_size, unit, disable_bias, integrate_bias, acti_func,
                       recurrent_acti_func, input, prev_hidden_state,
                       prev_cell_state, hidden_state,
                       training && enable_cell_state_zoneout ? lstm_cell_state
@@ -388,7 +388,7 @@ void ZoneoutLSTMCellLayer::calcDerivative(RunLayerContext &context) {
     context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]);
   const Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]);
 
-  lstmcell_calcDerivative(d_ifgo, weight_ih, outgoing_derivative);
+  lstmcell_calcDerivative(outgoing_derivative, weight_ih, d_ifgo);
 }
 
 void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
@@ -519,7 +519,7 @@ void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
   }
 
   lstmcell_calcGradient(
-    unit, batch_size, disable_bias, integrate_bias, acti_func,
+    batch_size, unit, disable_bias, integrate_bias, acti_func,
     recurrent_acti_func, input, prev_hidden_state, d_prev_hidden_state,
     prev_cell_state, d_prev_cell_state,
     enable_hidden_state_zoneout ? d_hidden_state_masked : d_hidden_state,
index 40a7894..ae70951 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
index ae812f9..12441ac 100644 (file)
@@ -77,7 +77,7 @@ IniWrapper fc_unroll_single__2(
   });
 
 std::unique_ptr<NeuralNetwork> makeFC() {
-  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  auto nn = std::make_unique<NeuralNetwork>();
   nn->setProperty({"batch_size=1"});
 
   auto outer_graph = makeGraph({
@@ -107,7 +107,7 @@ std::unique_ptr<NeuralNetwork> makeFC() {
 }
 
 std::unique_ptr<NeuralNetwork> makeFCClipped() {
-  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  auto nn = std::make_unique<NeuralNetwork>();
   nn->setProperty({"batch_size=1"});
 
   auto outer_graph = makeGraph({
@@ -138,7 +138,7 @@ std::unique_ptr<NeuralNetwork> makeFCClipped() {
 }
 
 static std::unique_ptr<NeuralNetwork> makeSingleLSTM() {
-  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  auto nn = std::make_unique<NeuralNetwork>();
   nn->setProperty({"batch_size=3"});
 
   auto outer_graph = makeGraph({
@@ -156,7 +156,7 @@ static std::unique_ptr<NeuralNetwork> makeSingleLSTM() {
 }
 
 static std::unique_ptr<NeuralNetwork> makeStackedLSTM() {
-  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  auto nn = std::make_unique<NeuralNetwork>();
   nn->setProperty({"batch_size=3"});
 
   auto outer_graph = makeGraph({
@@ -175,49 +175,49 @@ static std::unique_ptr<NeuralNetwork> makeStackedLSTM() {
   return nn;
 }
 
-// static std::unique_ptr<NeuralNetwork> makeSingleBidirectionalLSTM() {
-//   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
-//   nn->setProperty({"batch_size=3"});
-
-//   auto outer_graph = makeGraph({
-//     {"input", {"name=input", "input_shape=1:2:2"}},
-//     {"lstm",
-//      {"name=a1", "unit=2", "integrate_bias=false", "return_sequences=true",
-//       "bidirectional=true"}},
-//     {"mse", {"name=loss", "input_layers=a1"}},
-//   });
-//   for (auto &node : outer_graph) {
-//     nn->addLayer(node);
-//   }
-
-//   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate =
-//   0.1"})); return nn;
-// }
-
-// static std::unique_ptr<NeuralNetwork> makeStackedBidirectionalLSTM() {
-//   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
-//   nn->setProperty({"batch_size=3"});
-
-//   auto outer_graph = makeGraph({
-//     {"input", {"name=input", "input_shape=1:2:2"}},
-//     {"lstm",
-//      {"name=a1", "unit=2", "integrate_bias=false", "return_sequences=true",
-//       "bidirectional=true"}},
-//     {"lstm",
-//      {"name=a2", "unit=2", "integrate_bias=false", "return_sequences=true",
-//       "bidirectional=true"}},
-//     {"mse", {"name=loss"}},
-//   });
-//   for (auto &node : outer_graph) {
-//     nn->addLayer(node);
-//   }
-
-//   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate =
-//   0.1"})); return nn;
-// }
+static std::unique_ptr<NeuralNetwork> makeSingleBidirectionalLSTM() {
+  auto nn = std::make_unique<NeuralNetwork>();
+  nn->setProperty({"batch_size=3"});
+
+  auto outer_graph = makeGraph({
+    {"input", {"name=input", "input_shape=1:2:2"}},
+    {"lstm",
+     {"name=a1", "unit=2", "integrate_bias=false", "return_sequences=true",
+      "bidirectional=true"}},
+    {"mse", {"name=loss"}},
+  });
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+  return nn;
+}
+
+static std::unique_ptr<NeuralNetwork> makeStackedBidirectionalLSTM() {
+  auto nn = std::make_unique<NeuralNetwork>();
+  nn->setProperty({"batch_size=3"});
+
+  auto outer_graph = makeGraph({
+    {"input", {"name=input", "input_shape=1:2:2"}},
+    {"lstm",
+     {"name=a1", "unit=2", "integrate_bias=false", "return_sequences=true",
+      "bidirectional=true"}},
+    {"lstm",
+     {"name=a2", "unit=2", "integrate_bias=false", "return_sequences=true",
+      "bidirectional=true"}},
+    {"mse", {"name=loss"}},
+  });
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+  return nn;
+}
 
 static std::unique_ptr<NeuralNetwork> makeSingleLSTMCell() {
-  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  auto nn = std::make_unique<NeuralNetwork>();
   nn->setProperty({"batch_size=3"});
 
   auto outer_graph = makeGraph({
@@ -257,7 +257,7 @@ static std::unique_ptr<NeuralNetwork> makeSingleLSTMCell() {
 }
 
 static std::unique_ptr<NeuralNetwork> makeStackedLSTMCell() {
-  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  auto nn = std::make_unique<NeuralNetwork>();
   nn->setProperty({"batch_size=3"});
 
   auto outer_graph = makeGraph({
@@ -310,7 +310,7 @@ static std::unique_ptr<NeuralNetwork> makeStackedLSTMCell() {
 }
 
 static std::unique_ptr<NeuralNetwork> makeSingleZoneoutLSTMCell() {
-  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  auto nn = std::make_unique<NeuralNetwork>();
   nn->setProperty({"batch_size=1"});
 
   auto outer_graph = makeGraph({
@@ -352,7 +352,7 @@ static std::unique_ptr<NeuralNetwork> makeSingleZoneoutLSTMCell() {
 }
 
 static std::unique_ptr<NeuralNetwork> makeStackedZoneoutLSTMCell() {
-  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  auto nn = std::make_unique<NeuralNetwork>();
   nn->setProperty({"batch_size=1"});
 
   auto outer_graph = makeGraph({
@@ -410,7 +410,7 @@ static std::unique_ptr<NeuralNetwork> makeStackedZoneoutLSTMCell() {
 }
 
 static std::unique_ptr<NeuralNetwork> makeSingleRNNCell() {
-  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  auto nn = std::make_unique<NeuralNetwork>();
   nn->setProperty({"batch_size=3"});
 
   auto outer_graph = makeGraph({
@@ -440,7 +440,7 @@ static std::unique_ptr<NeuralNetwork> makeSingleRNNCell() {
 }
 
 static std::unique_ptr<NeuralNetwork> makeStackedRNNCell() {
-  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  auto nn = std::make_unique<NeuralNetwork>();
   nn->setProperty({"batch_size=3"});
 
   auto outer_graph = makeGraph({
@@ -472,7 +472,7 @@ static std::unique_ptr<NeuralNetwork> makeStackedRNNCell() {
 }
 
 static std::unique_ptr<NeuralNetwork> makeSingleGRUCell() {
-  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  auto nn = std::make_unique<NeuralNetwork>();
   nn->setProperty({"batch_size=3"});
 
   auto outer_graph = makeGraph({
@@ -509,7 +509,7 @@ static std::unique_ptr<NeuralNetwork> makeSingleGRUCell() {
 }
 
 static std::unique_ptr<NeuralNetwork> makeStackedGRUCell() {
-  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  auto nn = std::make_unique<NeuralNetwork>();
   nn->setProperty({"batch_size=3"});
 
   auto outer_graph = makeGraph({
@@ -567,10 +567,10 @@ INSTANTIATE_TEST_CASE_P(
                  ModelTestOption::COMPARE_V2),
     mkModelTc_V2(makeSingleLSTM, "lstm_single", ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedLSTM, "lstm_stacked", ModelTestOption::ALL_V2),
-    // mkModelTc_V2(makeSingleBidirectionalLSTM, "bidirectional_lstm_single",
-    //              ModelTestOption::ALL_V2),
-    // mkModelTc_V2(makeStackedBidirectionalLSTM, "bidirectional_lstm_stacked",
-    //              ModelTestOption::ALL_V2),
+    mkModelTc_V2(makeSingleBidirectionalLSTM, "bidirectional_lstm_single",
+                 ModelTestOption::ALL_V2),
+    mkModelTc_V2(makeStackedBidirectionalLSTM, "bidirectional_lstm_stacked",
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeSingleLSTMCell, "lstmcell_single",
                  ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedLSTMCell, "lstmcell_stacked",