[layer] zoneout lstm bug fix
authorParichay Kapoor <pk.kapoor@samsung.com>
Mon, 6 Dec 2021 17:10:18 +0000 (02:10 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 7 Dec 2021 06:15:53 +0000 (15:15 +0900)
zoneout lstm bug fix for handling the scenario where mask rate is set to
zero for either hidden or cell.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/layers/zoneout_lstmcell.cpp

index 921b1e5..abe83d4 100644 (file)
@@ -90,10 +90,6 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
                 std::invalid_argument)
     << "unit property missing for zoneout_lstmcell layer";
   const unsigned int unit = std::get<props::Unit>(zoneout_lstmcell_props).get();
-  const float hidden_state_zoneout_rate =
-    std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props);
-  const float cell_state_zoneout_rate =
-    std::get<CellStateZoneOutRate>(zoneout_lstmcell_props);
   const bool test = std::get<Test>(zoneout_lstmcell_props);
   const unsigned int max_timestep =
     std::get<props::MaxTimestep>(zoneout_lstmcell_props);
@@ -166,7 +162,7 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
       context.requestWeight(hidden_state_zoneout_mask_dim,
                             Tensor::Initializer::NONE, WeightRegularizer::NONE,
                             1.0f, "hidden_state_zoneout_mask", false);
-  } else if (hidden_state_zoneout_rate > epsilon) {
+  } else {
     wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask] =
       context.requestTensor(
         hidden_state_zoneout_mask_dim, "hidden_state_zoneout_mask",
@@ -179,7 +175,7 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
     wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestWeight(
       cell_state_zoneout_mask_dim, Tensor::Initializer::NONE,
       WeightRegularizer::NONE, 1.0f, "cell_state_zoneout_mask", false);
-  } else if (cell_state_zoneout_rate > epsilon) {
+  } else {
     wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestTensor(
       cell_state_zoneout_mask_dim, "cell_state_zoneout_mask",
       Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN);
@@ -331,75 +327,69 @@ void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
                                init_lstm_context::getTensors(tensors));
   lstmcellcorelayer.forwarding(core_context, training);
 
-  if (hidden_state_zoneout_rate > epsilon) {
-    if (training) {
-      Tensor &hidden_state_zoneout_mask =
-        test ? context.getWeight(
-                 wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
-             : context.getTensor(
-                 wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
-      hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
-      Tensor next_hidden_state_zoneout_mask =
-        hidden_state_zoneout_mask.getBatchSlice(timestep, 1);
-      next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
-      Tensor prev_hidden_state_zoneout_mask;
-      if (!test) {
-        prev_hidden_state_zoneout_mask =
-          next_hidden_state_zoneout_mask.zoneout_mask(
-            hidden_state_zoneout_rate);
-      } else {
-        next_hidden_state_zoneout_mask.multiply(-1.0f,
-                                                prev_hidden_state_zoneout_mask);
-        prev_hidden_state_zoneout_mask.add_i(1.0f);
-      }
-
-      Tensor &hidden_state_origin = context.getTensor(hidden_state_origin_idx);
-      hidden_state_origin.reshape({max_timestep, 1, batch_size, unit});
-      Tensor next_hidden_state_origin =
-        hidden_state_origin.getBatchSlice(timestep, 1);
-      next_hidden_state_origin.reshape({batch_size, 1, 1, unit});
-
-      next_hidden_state_origin.multiply(next_hidden_state_zoneout_mask,
-                                        next_hidden_state);
-      prev_hidden_state.multiply(prev_hidden_state_zoneout_mask,
-                                 next_hidden_state, 1.0f);
+  if (training) {
+    Tensor &hidden_state_zoneout_mask =
+      test ? context.getWeight(
+               wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
+           : context.getTensor(
+               wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
+    hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+    Tensor next_hidden_state_zoneout_mask =
+      hidden_state_zoneout_mask.getBatchSlice(timestep, 1);
+    next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+    Tensor prev_hidden_state_zoneout_mask;
+    if (!test) {
+      prev_hidden_state_zoneout_mask =
+        next_hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
+    } else {
+      next_hidden_state_zoneout_mask.multiply(-1.0f,
+                                              prev_hidden_state_zoneout_mask);
+      prev_hidden_state_zoneout_mask.add_i(1.0f);
     }
-    // Todo: zoneout at inference
+
+    Tensor &hidden_state_origin = context.getTensor(hidden_state_origin_idx);
+    hidden_state_origin.reshape({max_timestep, 1, batch_size, unit});
+    Tensor next_hidden_state_origin =
+      hidden_state_origin.getBatchSlice(timestep, 1);
+    next_hidden_state_origin.reshape({batch_size, 1, 1, unit});
+
+    next_hidden_state_origin.multiply(next_hidden_state_zoneout_mask,
+                                      next_hidden_state);
+    prev_hidden_state.multiply(prev_hidden_state_zoneout_mask,
+                               next_hidden_state, 1.0f);
   }
-  if (cell_state_zoneout_rate > epsilon) {
-    if (training) {
-      Tensor &cell_state_zoneout_mask =
-        test ? context.getWeight(
-                 wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
-             : context.getTensor(
-                 wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
-      cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
-      Tensor next_cell_state_zoneout_mask =
-        cell_state_zoneout_mask.getBatchSlice(timestep, 1);
-      next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
-      Tensor prev_cell_state_zoneout_mask;
-      if (!test) {
-        prev_cell_state_zoneout_mask =
-          next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
-      } else {
-        next_cell_state_zoneout_mask.multiply(-1.0f,
-                                              prev_cell_state_zoneout_mask);
-        prev_cell_state_zoneout_mask.add_i(1.0f);
-      }
-
-      Tensor &cell_state_origin = context.getTensor(cell_state_origin_idx);
-      cell_state_origin.reshape({max_timestep, 1, batch_size, unit});
-      Tensor next_cell_state_origin =
-        cell_state_origin.getBatchSlice(timestep, 1);
-      next_cell_state_origin.reshape({batch_size, 1, 1, unit});
-
-      next_cell_state_origin.multiply(next_cell_state_zoneout_mask,
-                                      next_cell_state);
-      prev_cell_state.multiply(prev_cell_state_zoneout_mask, next_cell_state,
-                               1.0f);
+
+  if (training) {
+    Tensor &cell_state_zoneout_mask =
+      test
+        ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
+        : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
+    cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+    Tensor next_cell_state_zoneout_mask =
+      cell_state_zoneout_mask.getBatchSlice(timestep, 1);
+    next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+    Tensor prev_cell_state_zoneout_mask;
+    if (!test) {
+      prev_cell_state_zoneout_mask =
+        next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
+    } else {
+      next_cell_state_zoneout_mask.multiply(-1.0f,
+                                            prev_cell_state_zoneout_mask);
+      prev_cell_state_zoneout_mask.add_i(1.0f);
     }
-    // Todo: zoneout at inference
+
+    Tensor &cell_state_origin = context.getTensor(cell_state_origin_idx);
+    cell_state_origin.reshape({max_timestep, 1, batch_size, unit});
+    Tensor next_cell_state_origin =
+      cell_state_origin.getBatchSlice(timestep, 1);
+    next_cell_state_origin.reshape({batch_size, 1, 1, unit});
+
+    next_cell_state_origin.multiply(next_cell_state_zoneout_mask,
+                                    next_cell_state);
+    prev_cell_state.multiply(prev_cell_state_zoneout_mask, next_cell_state,
+                             1.0f);
   }
+  // Todo: zoneout at inference
 
   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
   output.copyData(next_hidden_state);
@@ -481,81 +471,77 @@ void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
   Tensor prev_cell_state_derivative;
   Tensor prev_hidden_state_derivative_residual;
   Tensor prev_cell_state_derivative_residual;
-  if (hidden_state_zoneout_rate > epsilon) {
-    Tensor &hidden_state_zoneout_mask =
-      test ? context.getWeight(
-               wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
-           : context.getTensor(
-               wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
-    hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
-    Tensor next_hidden_state_zoneout_mask =
-      hidden_state_zoneout_mask.getBatchSlice(timestep, 1);
-    next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
-    Tensor prev_hidden_state_zoneout_mask;
-    if (!test) {
-      prev_hidden_state_zoneout_mask =
-        next_hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
-    } else {
-      next_hidden_state_zoneout_mask.multiply(-1.0f,
-                                              prev_hidden_state_zoneout_mask);
-      prev_hidden_state_zoneout_mask.add_i(1.0f);
-    }
 
-    if (timestep) {
-      prev_hidden_state_derivative =
-        hidden_state_derivative.getBatchSlice(timestep - 1, 1);
-      prev_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
-      next_hidden_state_derivative.multiply(
-        prev_hidden_state_zoneout_mask, prev_hidden_state_derivative_residual);
-    }
+  Tensor &hidden_state_zoneout_mask =
+    test
+      ? context.getWeight(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
+      : context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
+  hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+  Tensor next_hidden_state_zoneout_mask =
+    hidden_state_zoneout_mask.getBatchSlice(timestep, 1);
+  next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+  Tensor prev_hidden_state_zoneout_mask;
+  if (!test) {
+    prev_hidden_state_zoneout_mask =
+      next_hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
+  } else {
+    next_hidden_state_zoneout_mask.multiply(-1.0f,
+                                            prev_hidden_state_zoneout_mask);
+    prev_hidden_state_zoneout_mask.add_i(1.0f);
+  }
 
-    Tensor &hidden_state_origin_derivative =
-      context.getTensorGrad(hidden_state_origin_idx);
-    hidden_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit});
-    Tensor next_hidden_state_origin_derivative =
-      hidden_state_origin_derivative.getBatchSlice(timestep, 1);
-    next_hidden_state_origin_derivative.reshape({batch_size, 1, 1, unit});
+  if (timestep) {
+    prev_hidden_state_derivative =
+      hidden_state_derivative.getBatchSlice(timestep - 1, 1);
+    prev_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
+    next_hidden_state_derivative.multiply(
+      prev_hidden_state_zoneout_mask, prev_hidden_state_derivative_residual);
+  }
 
-    next_hidden_state_derivative.multiply(next_hidden_state_zoneout_mask,
-                                          next_hidden_state_origin_derivative);
+  Tensor &hidden_state_origin_derivative =
+    context.getTensorGrad(hidden_state_origin_idx);
+  hidden_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit});
+  Tensor next_hidden_state_origin_derivative =
+    hidden_state_origin_derivative.getBatchSlice(timestep, 1);
+  next_hidden_state_origin_derivative.reshape({batch_size, 1, 1, unit});
+
+  next_hidden_state_derivative.multiply(next_hidden_state_zoneout_mask,
+                                        next_hidden_state_origin_derivative);
+
+  Tensor &cell_state_zoneout_mask =
+    test
+      ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
+      : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
+  cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+  Tensor next_cell_state_zoneout_mask =
+    cell_state_zoneout_mask.getBatchSlice(timestep, 1);
+  next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+  Tensor prev_cell_state_zoneout_mask;
+  if (!test) {
+    prev_cell_state_zoneout_mask =
+      next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
+  } else {
+    next_cell_state_zoneout_mask.multiply(-1.0f, prev_cell_state_zoneout_mask);
+    prev_cell_state_zoneout_mask.add_i(1.0f);
   }
-  if (cell_state_zoneout_rate > epsilon) {
-    Tensor &cell_state_zoneout_mask =
-      test
-        ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
-        : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
-    cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
-    Tensor next_cell_state_zoneout_mask =
-      cell_state_zoneout_mask.getBatchSlice(timestep, 1);
-    next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
-    Tensor prev_cell_state_zoneout_mask;
-    if (!test) {
-      prev_cell_state_zoneout_mask =
-        next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
-    } else {
-      next_cell_state_zoneout_mask.multiply(-1.0f,
-                                            prev_cell_state_zoneout_mask);
-      prev_cell_state_zoneout_mask.add_i(1.0f);
-    }
 
-    if (timestep) {
-      prev_cell_state_derivative =
-        cell_state_derivative.getBatchSlice(timestep - 1, 1);
-      prev_cell_state_derivative.reshape({batch_size, 1, 1, unit});
-      next_cell_state_derivative.multiply(prev_cell_state_zoneout_mask,
-                                          prev_cell_state_derivative_residual);
-    }
+  if (timestep) {
+    prev_cell_state_derivative =
+      cell_state_derivative.getBatchSlice(timestep - 1, 1);
+    prev_cell_state_derivative.reshape({batch_size, 1, 1, unit});
+    next_cell_state_derivative.multiply(prev_cell_state_zoneout_mask,
+                                        prev_cell_state_derivative_residual);
+  }
 
-    Tensor &cell_state_origin_derivative =
-      context.getTensorGrad(cell_state_origin_idx);
-    cell_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit});
-    Tensor next_cell_state_origin_derivative =
-      cell_state_origin_derivative.getBatchSlice(timestep, 1);
-    next_cell_state_origin_derivative.reshape({batch_size, 1, 1, unit});
+  Tensor &cell_state_origin_derivative =
+    context.getTensorGrad(cell_state_origin_idx);
+  cell_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit});
+  Tensor next_cell_state_origin_derivative =
+    cell_state_origin_derivative.getBatchSlice(timestep, 1);
+  next_cell_state_origin_derivative.reshape({batch_size, 1, 1, unit});
 
-    next_cell_state_derivative.multiply(next_cell_state_zoneout_mask,
-                                        next_cell_state_origin_derivative);
-  }
+  next_cell_state_derivative.multiply(next_cell_state_zoneout_mask,
+                                      next_cell_state_origin_derivative);
 
   init_lstm_context::fillWeights(weights, context, true, max_timestep, timestep,
                                  test);
@@ -574,22 +560,13 @@ void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
   lstmcellcorelayer.calcGradient(core_context);
 
   if (timestep) {
-    if (hidden_state_zoneout_rate > epsilon) {
-      prev_hidden_state_derivative.add_i(prev_hidden_state_derivative_residual);
-    }
-    if (cell_state_zoneout_rate > epsilon) {
-      prev_cell_state_derivative.add_i(prev_cell_state_derivative_residual);
-    }
+    prev_hidden_state_derivative.add_i(prev_hidden_state_derivative_residual);
+    prev_cell_state_derivative.add_i(prev_cell_state_derivative_residual);
   }
 }
 
 void ZoneoutLSTMCellLayer::setBatch(RunLayerContext &context,
                                     unsigned int batch) {
-  const float hidden_state_zoneout_rate =
-    std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props);
-  const float cell_state_zoneout_rate =
-    std::get<CellStateZoneOutRate>(zoneout_lstmcell_props);
-  const bool test = std::get<Test>(zoneout_lstmcell_props);
   const unsigned int max_timestep =
     std::get<props::MaxTimestep>(zoneout_lstmcell_props);
   context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state],
@@ -600,14 +577,10 @@ void ZoneoutLSTMCellLayer::setBatch(RunLayerContext &context,
   context.updateTensor(cell_state_origin_idx, max_timestep * batch);
   context.updateTensor(wt_idx[ZoneoutLSTMParams::ifgo], max_timestep * batch);
 
-  if (hidden_state_zoneout_rate > epsilon && !test) {
-    context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask],
-                         max_timestep * batch);
-  }
-  if (cell_state_zoneout_rate > epsilon && !test) {
-    context.updateTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask],
-                         max_timestep * batch);
-  }
+  context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask],
+                       max_timestep * batch);
+  context.updateTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask],
+                       max_timestep * batch);
 }
 
 } // namespace nntrainer