[Fix] fix lifespan of recurrent cells
author    Jihoon Lee <jhoon.it.lee@samsung.com>
          Thu, 13 Jan 2022 15:25:24 +0000 (00:25 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
          Wed, 19 Jan 2022 10:51:43 +0000 (19:51 +0900)
This patch fixes the tensor lifespans of the recurrent cells:
lstmcell now requests the shortest lifespan each output actually
needs, while zoneout_lstmcell keeps FORWARD_DERIV_LIFESPAN for now
so that its tests keep passing.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
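
For context, a TensorLifespan in nntrainer tells the memory planner which phases of a training iteration a tensor must stay valid for; a shorter lifespan lets the planner reuse that memory sooner. Below is a minimal sketch of the distinction this patch relies on, using only enumerator names that appear in the diff; the real enum lives in nntrainer's tensor spec header and encodes these as combinable phase bits, so the actual definition differs.

    // Sketch only, not the real definition: illustrates the lifespans
    // touched by this patch.
    enum class TensorLifespan {
      FORWARD_FUNC_LIFESPAN,  // valid only while forwarding() runs
      FORWARD_GRAD_LIFESPAN,  // valid through forwarding() and calcGradient()
      FORWARD_DERIV_LIFESPAN, // valid through forwarding() and calcDerivative()
      ITERATION_LIFESPAN,     // valid for the whole forward/backward iteration
    };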
nntrainer/layers/lstmcell.cpp
nntrainer/layers/zoneout_lstmcell.cpp

diff --git a/nntrainer/layers/lstmcell.cpp b/nntrainer/layers/lstmcell.cpp
index 69655f7..32c3b76 100644
--- a/nntrainer/layers/lstmcell.cpp
+++ b/nntrainer/layers/lstmcell.cpp
@@ -105,10 +105,10 @@ void LSTMCellLayer::finalize(InitLayerContext &context) {
   std::vector<VarGradSpecV2> out_specs;
   out_specs.push_back(
     InitLayerContext::outSpec(output_hidden_state_dim, "output_hidden_state",
-                              TensorLifespan::FORWARD_DERIV_LIFESPAN));
+                              TensorLifespan::FORWARD_FUNC_LIFESPAN));
   out_specs.push_back(
     InitLayerContext::outSpec(output_cell_state_dim, "output_cell_state",
-                              TensorLifespan::FORWARD_DERIV_LIFESPAN));
+                              TensorLifespan::FORWARD_GRAD_LIFESPAN));
   context.requestOutputs(std::move(out_specs));
 
  // weight_initializer can be set separately. weight_ih initializer,
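
Read as plain code, the hunk above makes the two outputs of the plain LSTM cell request different lifespans: the hidden state only needs to live through forwarding, while the cell state must survive until gradient calculation, presumably because it is read again when the gate gradients are computed. A condensed sketch of the resulting fragment of LSTMCellLayer::finalize(), assembled from the lines above:

    std::vector<VarGradSpecV2> out_specs;
    out_specs.push_back(
      InitLayerContext::outSpec(output_hidden_state_dim, "output_hidden_state",
                                // consumed during forwarding only
                                TensorLifespan::FORWARD_FUNC_LIFESPAN));
    out_specs.push_back(
      InitLayerContext::outSpec(output_cell_state_dim, "output_cell_state",
                                // must survive until gradient calculation
                                TensorLifespan::FORWARD_GRAD_LIFESPAN));
    context.requestOutputs(std::move(out_specs));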
diff --git a/nntrainer/layers/zoneout_lstmcell.cpp b/nntrainer/layers/zoneout_lstmcell.cpp
index a386cad..9620541 100644
--- a/nntrainer/layers/zoneout_lstmcell.cpp
+++ b/nntrainer/layers/zoneout_lstmcell.cpp
@@ -129,47 +129,50 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
   const TensorDim output_cell_state_dim = input_cell_state_dim;
 
   std::vector<VarGradSpecV2> out_specs;
+  /// note: these out specs could use FORWARD_FUNC_LIFESPAN, but they are
+  /// kept at FORWARD_DERIV_LIFESPAN until the tests are updated
   out_specs.push_back(
     InitLayerContext::outSpec(output_hidden_state_dim, "output_hidden_state",
                               TensorLifespan::FORWARD_DERIV_LIFESPAN));
+  //////////////////////////  TensorLifespan::FORWARD_FUNC_LIFESPAN));
   out_specs.push_back(
     InitLayerContext::outSpec(output_cell_state_dim, "output_cell_state",
                               TensorLifespan::FORWARD_DERIV_LIFESPAN));
+  //////////////////////////  TensorLifespan::FORWARD_FUNC_LIFESPAN));
   context.requestOutputs(std::move(out_specs));
 
-  // weight_initializer can be set seperately. weight_ih initializer,
-  // weight_hh initializer kernel initializer & recurrent_initializer in keras
-  // for now, it is set same way.
+  // weight_ih / weight_hh initializers (kernel / recurrent initializers in
+  // keras) could be set separately; for now they are set the same way.
 
   // - weight_ih ( input to hidden )
   //  : [ 1, 1, feature_size, NUM_GATE x unit ] -> i, f, g, o
   TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
   wt_idx[ZoneoutLSTMParams::weight_ih] =
     context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
                           weight_regularizer_constant, "weight_ih", true);
   // - weight_hh ( hidden to hidden )
   //  : [ 1, 1, unit, NUM_GATE x unit ] -> i, f, g, o
   TensorDim weight_hh_dim({unit, NUM_GATE * unit});
   wt_idx[ZoneoutLSTMParams::weight_hh] =
     context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
                           weight_regularizer_constant, "weight_hh", true);
   if (!disable_bias) {
     if (integrate_bias) {
       // - bias_h ( input bias and hidden bias are integrated into 1 bias )
       //  : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g, o
       TensorDim bias_h_dim({NUM_GATE * unit});
       wt_idx[ZoneoutLSTMParams::bias_h] =
         context.requestWeight(bias_h_dim, bias_initializer,
                               WeightRegularizer::NONE, 1.0f, "bias_h", true);
     } else {
       // - bias_ih ( input bias )
       //  : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g, o
       TensorDim bias_ih_dim({NUM_GATE * unit});
       wt_idx[ZoneoutLSTMParams::bias_ih] =
         context.requestWeight(bias_ih_dim, bias_initializer,
                               WeightRegularizer::NONE, 1.0f, "bias_ih", true);
       // - bias_hh ( hidden bias )
       //  : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g, o
       TensorDim bias_hh_dim({NUM_GATE * unit});
       wt_idx[ZoneoutLSTMParams::bias_hh] =
         context.requestWeight(bias_hh_dim, bias_initializer,
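
The note kept in the hunk above records why zoneout_lstmcell was not given the same treatment as lstmcell yet: its tests apparently still depend on the longer lifespan, so both out specs stay at FORWARD_DERIV_LIFESPAN and the intended end state is left commented out. If the tests stop depending on that, the follow-up would presumably just enable those commented lines:

    // Intended follow-up (the commented-out lines in the patch): both
    // outputs drop to the lifespan that forwarding alone needs.
    out_specs.push_back(
      InitLayerContext::outSpec(output_hidden_state_dim, "output_hidden_state",
                                TensorLifespan::FORWARD_FUNC_LIFESPAN));
    out_specs.push_back(
      InitLayerContext::outSpec(output_cell_state_dim, "output_cell_state",
                                TensorLifespan::FORWARD_FUNC_LIFESPAN));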
@@ -177,19 +180,19 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
     }
   }
 
   /** ifgo_dim = [ batch_size, 1, 1, NUM_GATE * unit ] */
   const TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit);
   wt_idx[ZoneoutLSTMParams::ifgo] =
     context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true,
                           TensorLifespan::ITERATION_LIFESPAN);
 
   /** lstm_cell_state_dim = [ batch_size, 1, 1, unit ] */
   const TensorDim lstm_cell_state_dim(batch_size, 1, 1, unit);
   wt_idx[ZoneoutLSTMParams::lstm_cell_state] = context.requestTensor(
     lstm_cell_state_dim, "lstm_cell_state", Tensor::Initializer::NONE, true,
     TensorLifespan::ITERATION_LIFESPAN);
 
   // hidden_state_zoneout_mask_dim = [ max_timestep * batch_size, 1, 1, unit ]
   const TensorDim hidden_state_zoneout_mask_dim(max_timestep * batch_size, 1, 1,
                                                 unit);
   if (test) {
@@ -203,7 +206,7 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
         hidden_state_zoneout_mask_dim, "hidden_state_zoneout_mask",
         Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN);
   }
   // cell_state_zoneout_mask_dim = [ max_timestep * batch_size, 1, 1, unit ]
   const TensorDim cell_state_zoneout_mask_dim(max_timestep * batch_size, 1, 1,
                                               unit);
   if (test) {
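
Unlike the outputs, the internal tensors of the zoneout cell (ifgo, lstm_cell_state, and the two zoneout masks) keep ITERATION_LIFESPAN, since they are produced in the forward pass and read again throughout the backward pass. For illustration, the request pattern as it appears in the hunks above; the trailing bool mirrors the call in the diff, and its exact meaning (presumably whether a gradient is needed) is an assumption here:

    /** ifgo holds the pre-activations of all four gates for one step:
     *  [ batch_size, 1, 1, NUM_GATE * unit ] */
    const TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit);
    wt_idx[ZoneoutLSTMParams::ifgo] =
      context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true,
                            TensorLifespan::ITERATION_LIFESPAN);

    /** one zoneout mask row per unrolled timestep and batch element:
     *  [ max_timestep * batch_size, 1, 1, unit ] */
    const TensorDim hidden_state_zoneout_mask_dim(max_timestep * batch_size, 1,
                                                  1, unit);

The masks span max_timestep * batch_size rows presumably because the cell is unrolled over time and each timestep keeps its own mask for the backward pass.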