[zoneout lstmcell] share zoneout mask tensors
authorhyeonseok lee <hs89.lee@samsung.com>
Wed, 5 Jan 2022 03:39:55 +0000 (12:39 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 21 Apr 2022 01:05:48 +0000 (10:05 +0900)
 - Makes zoneout_mask tensors to be shared when it is unrolled

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/layers/zoneout_lstmcell.cpp

index b0a7845..0e9d19f 100644 (file)
@@ -215,9 +215,10 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
                             1.0f, 0.0f, "hidden_state_zoneout_mask", false);
   } else {
     wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask] =
-      context.requestTensor(
-        hidden_state_zoneout_mask_dim, "hidden_state_zoneout_mask",
-        Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN);
+      context.requestTensor(hidden_state_zoneout_mask_dim,
+                            "hidden_state_zoneout_mask",
+                            Tensor::Initializer::NONE, false,
+                            TensorLifespan::ITERATION_LIFESPAN, false);
   }
   // cell_state_zoneout_mask_dim = [ max_timestep *
   // batch_size, 1, 1, unit ]
@@ -230,7 +231,8 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
   } else {
     wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestTensor(
       cell_state_zoneout_mask_dim, "cell_state_zoneout_mask",
-      Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN);
+      Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN,
+      false);
   }
 
   acti_func.setActiFunc(hidden_state_activation_type);
@@ -326,7 +328,6 @@ void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
     hidden_state.multiply_i(hidden_state_zoneout_mask);
     prev_hidden_state.multiply(prev_hidden_state_zoneout_mask, hidden_state,
                                1.0f);
-
     Tensor &cs_zoneout_mask =
       test
         ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
@@ -350,11 +351,11 @@ void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
 }
 
 void ZoneoutLSTMCellLayer::calcDerivative(RunLayerContext &context) {
-  Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]);
-  const Tensor &weight_ih =
-    context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]);
   Tensor &outgoing_derivative =
     context.getOutgoingDerivative(INOUT_INDEX::INPUT);
+  const Tensor &weight_ih =
+    context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]);
+  const Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]);
 
   lstmcell_calcDerivative(d_ifgo, weight_ih, outgoing_derivative);
 }
@@ -438,6 +439,7 @@ void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
   }
 
   Tensor d_prev_hidden_state_residual;
+  Tensor d_hidden_state_masked;
   Tensor &hs_zoneout_mask =
     test
       ? context.getWeight(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
@@ -450,7 +452,6 @@ void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
 
   d_hidden_state.multiply(prev_hidden_state_zoneout_mask,
                           d_prev_hidden_state_residual);
-  Tensor d_hidden_state_masked;
   d_hidden_state.multiply(hidden_state_zoneout_mask, d_hidden_state_masked);
 
   Tensor d_prev_cell_state_residual;
@@ -482,7 +483,7 @@ void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
 void ZoneoutLSTMCellLayer::setBatch(RunLayerContext &context,
                                     unsigned int batch) {
   const unsigned int max_timestep =
-    std::get<props::MaxTimestep>(zoneout_lstmcell_props);
+    std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
 
   context.updateTensor(wt_idx[ZoneoutLSTMParams::ifgo], batch);
   context.updateTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state], batch);