[zoneout lstmcell] Implement zoneout lstm cell
author: hyeonseok lee <hs89.lee@samsung.com>
Tue, 30 Nov 2021 18:51:18 +0000 (03:51 +0900)
committer: Jijoong Moon <jijoong.moon@samsung.com>
Mon, 6 Dec 2021 12:35:08 +0000 (21:35 +0900)
 - Zoneout lstmcell is based on the paper and the github repo
   which is mentioned in paper.
 - Todo: Zoneout at inference time is not implemented yet.

refer: https://arxiv.org/pdf/1606.01305.pdf
       https://github.com/teganmaharaj/zoneout

Self evaluation:

    Build test: [X]Passed [ ]Failed [ ]Skipped
    Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
api/ccapi/include/layer.h
jni/Android.mk
nntrainer/app_context.cpp
nntrainer/compiler/recurrent_realizer.cpp
nntrainer/layers/lstmcell.cpp
nntrainer/layers/lstmcell_core.cpp
nntrainer/layers/meson.build
nntrainer/layers/zoneout_lstmcell.cpp [new file with mode: 0644]
nntrainer/layers/zoneout_lstmcell.h [new file with mode: 0644]
nntrainer/tensor/tensor.cpp
nntrainer/tensor/tensor.h

index d49ca6f..56b6ac3 100644 (file)
@@ -71,6 +71,7 @@ enum LayerType {
   LAYER_RESHAPE,                           /**< Reshape Layer type */
   LAYER_RNNCELL,                           /**< RNN Cell Layer type */
   LAYER_LSTMCELL,                          /**< LSTM Cell Layer type */
+  LAYER_ZONEOUT_LSTMCELL,                  /**< Zoneout LSTM Cell Layer type */
   LAYER_GRUCELL,                           /**< GRU Cell Layer type */
   LAYER_REDUCE_MEAN,                       /**< Reduce mean Layer type */
   LAYER_LOSS_MSE = 500,             /**< Mean Squared Error Loss Layer type */
@@ -339,6 +340,14 @@ LSTMCell(const std::vector<std::string> &properties = {}) {
 }
 
 /**
+ * @brief Helper function to create ZoneoutLSTMCell layer
+ */
+inline std::unique_ptr<Layer>
+ZoneoutLSTMCell(const std::vector<std::string> &properties = {}) {
+  return createLayer(LayerType::LAYER_ZONEOUT_LSTMCELL, properties);
+}
+
+/**
  * @brief Helper function to create GRU layer
  */
 inline std::unique_ptr<Layer>
index 997bee3..bb99159 100644 (file)
@@ -175,6 +175,7 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/lstm.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/lstmcell.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/lstmcell_core.cpp \
+                  $(NNTRAINER_ROOT)/nntrainer/layers/zoneout_lstmcell.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/gru.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/grucell.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/time_dist.cpp \
index 66c930f..2e44a9a 100644 (file)
@@ -66,6 +66,7 @@
 #include <rnncell.h>
 #include <split_layer.h>
 #include <time_dist.h>
+#include <zoneout_lstmcell.h>
 
 #ifdef ENABLE_TFLITE_BACKBONE
 #include <tflite_layer.h>
@@ -251,6 +252,9 @@ static void add_default_object(AppContext &ac) {
                      LayerType::LAYER_LSTM);
   ac.registerFactory(nntrainer::createLayer<LSTMCellLayer>, LSTMCellLayer::type,
                      LayerType::LAYER_LSTMCELL);
+  ac.registerFactory(nntrainer::createLayer<ZoneoutLSTMCellLayer>,
+                     ZoneoutLSTMCellLayer::type,
+                     LayerType::LAYER_ZONEOUT_LSTMCELL);
   ac.registerFactory(nntrainer::createLayer<SplitLayer>, SplitLayer::type,
                      LayerType::LAYER_SPLIT);
   ac.registerFactory(nntrainer::createLayer<GRULayer>, GRULayer::type,
index 21c1b9e..4d866c6 100644 (file)
@@ -24,6 +24,7 @@
 #include <remap_realizer.h>
 #include <rnncell.h>
 #include <util_func.h>
+#include <zoneout_lstmcell.h>
 
 namespace nntrainer {
 
@@ -134,6 +135,7 @@ static void propagateTimestep(LayerNode *node, unsigned int time_step,
            node->getType() == LSTMLayer::type ||
            node->getType() == LSTMCellLayer::type ||
            node->getType() == LSTMCellCoreLayer::type ||
+           node->getType() == ZoneoutLSTMCellLayer::type ||
            node->getType() == GRUCellLayer::type;
   };
 
index c9905a1..e2c4244 100644 (file)
@@ -6,6 +6,8 @@
  * @date   17 March 2021
  * @brief  This is LSTMCell Layer Class of Neural Network
  * @see    https://github.com/nnstreamer/nntrainer
+ *         https://arxiv.org/pdf/1606.01305.pdf
+ *         https://github.com/teganmaharaj/zoneout
  * @author Parichay Kapoor <pk.kapoor@samsung.com>
  * @bug    No known bugs except for NYI items
  *
index 3930cc6..5e73c7d 100644 (file)
@@ -125,7 +125,7 @@ void fillInputs(std::vector<Var_Grad> &inputs, RunLayerContext &context,
     }
   }
 
-  inputs[0] = Var_Grad(input, outgoing_derivative);
+  inputs[0] = Var_Grad(input, outgoing_derivative, "lstmcell_core input");
   inputs[1] = Var_Grad(prev_hidden_state, prev_hidden_state_derivative,
                        context.getTensorName(wt_idx[1]));
   inputs[2] = Var_Grad(prev_cell_state, prev_cell_state_derivative,
@@ -220,7 +220,6 @@ void fillTensors(std::vector<Var_Grad> &tensors, RunLayerContext &context,
   }
   tensors[0] =
     Var_Grad(ifgo_t, ifgo_derivative_t, context.getTensorName(wt_idx[0]));
-  context.getTensorName(wt_idx[0]);
 #endif
 }
 
index e863dbc..b3cf7ce 100644 (file)
@@ -27,6 +27,7 @@ layer_sources = [
   'lstm.cpp',
   'lstmcell.cpp',
   'lstmcell_core.cpp',
+  'zoneout_lstmcell.cpp',
   'time_dist.cpp',
   'common_properties.cpp',
   'split_layer.cpp',
diff --git a/nntrainer/layers/zoneout_lstmcell.cpp b/nntrainer/layers/zoneout_lstmcell.cpp
new file mode 100644 (file)
index 0000000..921b1e5
--- /dev/null
@@ -0,0 +1,613 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
+ *
+ * @file   zoneout_lstmcell.cpp
+ * @date   30 November 2021
+ * @brief  This is ZoneoutLSTMCell Layer Class of Neural Network
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author hyeonseok lee <hs89.lee@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
+
+#include <layer_context.h>
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <node_exporter.h>
+#include <zoneout_lstmcell.h>
+
+namespace nntrainer {
+
+static constexpr size_t SINGLE_INOUT_IDX = 0;
+
+enum ZoneoutLSTMParams {
+  weight_ih,
+  weight_hh,
+  bias_ih,
+  hidden_state,
+  cell_state,
+  ifgo,
+  hidden_state_zoneout_mask,
+  cell_state_zoneout_mask,
+};
+
+unsigned int hidden_state_origin_idx = 0, cell_state_origin_idx = 0;
+
+const std::vector<unsigned int>
+getInputIdx(std::array<unsigned int, 8> &wt_idx) {
+  std::vector<unsigned int> ret(3);
+  ret[0] = SINGLE_INOUT_IDX;
+  ret[1] = wt_idx[ZoneoutLSTMParams::hidden_state];
+  ret[2] = wt_idx[ZoneoutLSTMParams::cell_state];
+  return ret;
+}
+
+const std::vector<unsigned int>
+getOutputIdx(std::array<unsigned int, 8> &wt_idx) {
+  std::vector<unsigned int> ret(3);
+  ret[0] = SINGLE_INOUT_IDX;
+  ret[1] = hidden_state_origin_idx;
+  ret[2] = cell_state_origin_idx;
+  return ret;
+}
+
+const std::vector<unsigned int>
+getTensorIdx(std::array<unsigned int, 8> &wt_idx) {
+  std::vector<unsigned int> ret(1);
+  ret[0] = wt_idx[ZoneoutLSTMParams::ifgo];
+  return ret;
+}
+
+ZoneoutLSTMCellLayer::ZoneoutLSTMCellLayer() :
+  LayerImpl(),
+  zoneout_lstmcell_props(props::Unit(), HiddenStateZoneOutRate(),
+                         CellStateZoneOutRate(), Test(), props::MaxTimestep(),
+                         props::Timestep()),
+  wt_idx({0}),
+  epsilon(1e-3) {}
+
+bool ZoneoutLSTMCellLayer::HiddenStateZoneOutRate::isValid(
+  const float &value) const {
+  if (value < 0.0f || value > 1.0f) {
+    return false;
+  } else {
+    return true;
+  }
+}
+
+bool ZoneoutLSTMCellLayer::CellStateZoneOutRate::isValid(
+  const float &value) const {
+  if (value < 0.0f || value > 1.0f) {
+    return false;
+  } else {
+    return true;
+  }
+}
+
+void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
+  NNTR_THROW_IF(std::get<props::Unit>(zoneout_lstmcell_props).empty(),
+                std::invalid_argument)
+    << "unit property missing for zoneout_lstmcell layer";
+  const unsigned int unit = std::get<props::Unit>(zoneout_lstmcell_props).get();
+  const float hidden_state_zoneout_rate =
+    std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props);
+  const float cell_state_zoneout_rate =
+    std::get<CellStateZoneOutRate>(zoneout_lstmcell_props);
+  const bool test = std::get<Test>(zoneout_lstmcell_props);
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(zoneout_lstmcell_props);
+
+#if !ENABLE_SHARING_WT_IDX
+  const Tensor::Initializer weight_initializer =
+    std::get<props::WeightInitializer>(*layer_impl_props);
+  const Tensor::Initializer bias_initializer =
+    std::get<props::BiasInitializer>(*layer_impl_props);
+  const nntrainer::WeightRegularizer weight_regularizer =
+    std::get<props::WeightRegularizer>(*layer_impl_props);
+  const float weight_regularizer_constant =
+    std::get<props::WeightRegularizerConstant>(*layer_impl_props);
+#endif
+
+  if (context.getNumInputs() != 1)
+    throw std::invalid_argument("ZoneoutLSTMCellLayer takes only one input");
+  if (std::get<props::MaxTimestep>(zoneout_lstmcell_props).empty())
+    throw std::invalid_argument("Number of unroll steps(max timestep) must be "
+                                "provided to zoneout LSTM cells");
+  if (std::get<props::Timestep>(zoneout_lstmcell_props).empty())
+    throw std::invalid_argument(
+      "Current timestep must be provided to zoneout LSTM cell");
+
+  // input_dim = [ batch_size, 1, 1, feature_size ]
+  const TensorDim &input_dim = context.getInputDimensions()[0];
+  if (input_dim.height() != 1 || input_dim.channel() != 1)
+    throw std::invalid_argument("Input must be single time dimension for "
+                                "ZoneoutLSTMCell (shape should be "
+                                "[batch_size, 1, 1, feature_size])");
+  const unsigned int batch_size = input_dim.batch();
+#if !ENABLE_SHARING_WT_IDX
+  const unsigned int feature_size = input_dim.width();
+#endif
+
+  // output_dim = [ batch_size, 1, 1, unit ]
+  const TensorDim output_dim(batch_size, 1, 1, unit);
+  context.setOutputDimensions({output_dim});
+
+#if !ENABLE_SHARING_WT_IDX
+  // weight_initializer can be set seperately. weight_ih initializer,
+  // weight_hh initializer kernel initializer & recurrent_initializer in keras
+  // for now, it is set same way.
+
+  // - weight_ih ( input to hidden )
+  //  : [1, 1, feature_size, NUM_GATE x unit] -> i, f, g, o
+  TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
+  wt_idx[ZoneoutLSTMParams::weight_ih] =
+    context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
+                          weight_regularizer_constant, "weight_ih", true);
+  // - weight_hh ( hidden to hidden )
+  //  : [1, 1, unit, NUM_GATE x unit] -> i, f, g, o
+  TensorDim weight_hh_dim({unit, NUM_GATE * unit});
+  wt_idx[ZoneoutLSTMParams::weight_hh] =
+    context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
+                          weight_regularizer_constant, "weight_hh", true);
+  // - bias_ih ( input bias )
+  //  : [1, 1, 1, NUM_GATE x unit] -> i, f, g, o
+  TensorDim bias_ih_dim({NUM_GATE * unit});
+  wt_idx[ZoneoutLSTMParams::bias_ih] =
+    context.requestWeight(bias_ih_dim, bias_initializer,
+                          WeightRegularizer::NONE, 1.0f, "bias_ih", true);
+#endif
+
+  // hidden_state_zoneout_mask_dim = [ max_timestep * batch_size, 1, 1, unit ]
+  const TensorDim hidden_state_zoneout_mask_dim(max_timestep * batch_size, 1, 1,
+                                                unit);
+  if (test) {
+    wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask] =
+      context.requestWeight(hidden_state_zoneout_mask_dim,
+                            Tensor::Initializer::NONE, WeightRegularizer::NONE,
+                            1.0f, "hidden_state_zoneout_mask", false);
+  } else if (hidden_state_zoneout_rate > epsilon) {
+    wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask] =
+      context.requestTensor(
+        hidden_state_zoneout_mask_dim, "hidden_state_zoneout_mask",
+        Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN);
+  }
+  // cell_state_zoneout_mask_dim = [ max_timestep * batch_size, 1, 1, unit ]
+  const TensorDim cell_state_zoneout_mask_dim(max_timestep * batch_size, 1, 1,
+                                              unit);
+  if (test) {
+    wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestWeight(
+      cell_state_zoneout_mask_dim, Tensor::Initializer::NONE,
+      WeightRegularizer::NONE, 1.0f, "cell_state_zoneout_mask", false);
+  } else if (cell_state_zoneout_rate > epsilon) {
+    wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestTensor(
+      cell_state_zoneout_mask_dim, "cell_state_zoneout_mask",
+      Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN);
+  }
+
+  /**
+   * TODO: hidden_state is only used from the previous timestep. Once it is
+   * supported as input, no need to cache the hidden_state itself
+   */
+  /** hidden_state_dim = [ max_timestep * batch_size, 1, 1, unit ] */
+  const TensorDim hidden_state_dim(max_timestep * batch_size, 1, 1, unit);
+  wt_idx[ZoneoutLSTMParams::hidden_state] = context.requestTensor(
+    hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true,
+    TensorLifespan::ITERATION_LIFESPAN, false);
+  /** cell_state_dim = [ max_timestep * batch_size, 1, 1, unit ] */
+  const TensorDim cell_state_dim(max_timestep * batch_size, 1, 1, unit);
+  wt_idx[ZoneoutLSTMParams::cell_state] = context.requestTensor(
+    cell_state_dim, "cell_state", Tensor::Initializer::NONE, true,
+    TensorLifespan::ITERATION_LIFESPAN, false);
+
+  hidden_state_origin_idx = context.requestTensor(
+    hidden_state_dim, "hidden_state_origin", Tensor::Initializer::NONE, true,
+    TensorLifespan::ITERATION_LIFESPAN, false);
+  cell_state_origin_idx = context.requestTensor(
+    cell_state_dim, "cell_state_origin", Tensor::Initializer::NONE, true,
+    TensorLifespan::ITERATION_LIFESPAN, false);
+
+#if !ENABLE_SHARING_WT_IDX
+  /**
+   * TODO: make this independent of time dimension once recurrent realizer
+   * supports requesting tensors which are not always shared
+   */
+  /** ifgo_dim = [ max_timestep * batch_size, 1, 1, NUM_GATE * unit ] */
+  const TensorDim ifgo_dim(max_timestep * batch_size, 1, 1, NUM_GATE * unit);
+  wt_idx[ZoneoutLSTMParams::ifgo] =
+    context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true,
+                          TensorLifespan::ITERATION_LIFESPAN, false);
+#endif
+
+  TensorDim hidden_state_t_dim({batch_size, 1, 1, unit});
+  TensorDim cell_state_t_dim({batch_size, 1, 1, unit});
+  InitLayerContext core_context(
+    {input_dim, hidden_state_t_dim, cell_state_t_dim}, 3,
+    context.executeInPlace(), context.getName());
+  lstmcellcorelayer.finalize(core_context);
+  init_lstm_context::fillLayerInitContext(context, core_context);
+}
+
+void ZoneoutLSTMCellLayer::setProperty(const std::vector<std::string> &values) {
+  std::vector<std::string> remain_props =
+    loadProperties(values, zoneout_lstmcell_props);
+
+  // Note: In current implementation the lstmcellcorelayer also has
+  // a properties related to weight. But it is not exported or used anywhere.
+  lstmcellcorelayer.setProperty(remain_props);
+  if (!std::get<props::Unit>(zoneout_lstmcell_props).empty()) {
+    lstmcellcorelayer.setProperty(
+      {"unit=" + to_string(std::get<props::Unit>(zoneout_lstmcell_props))});
+  }
+
+#if !ENABLE_SHARING_WT_IDX
+  // To remove lstmcell core layer's properties
+  std::tuple<props::HiddenStateActivation, props::RecurrentActivation>
+    lstmcell_core_props;
+  std::vector<std::string> impl_props =
+    loadProperties(remain_props, lstmcell_core_props);
+
+  LayerImpl::setProperty(impl_props);
+#endif
+}
+
+void ZoneoutLSTMCellLayer::exportTo(Exporter &exporter,
+                                    const ExportMethods &method) const {
+#if !ENABLE_SHARING_WT_IDX
+  LayerImpl::exportTo(exporter, method);
+#endif
+  exporter.saveResult(
+    std::forward_as_tuple(
+      std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props),
+      std::get<CellStateZoneOutRate>(zoneout_lstmcell_props),
+      std::get<Test>(zoneout_lstmcell_props),
+      std::get<props::MaxTimestep>(zoneout_lstmcell_props),
+      std::get<props::Timestep>(zoneout_lstmcell_props)),
+    method, this);
+  lstmcellcorelayer.exportTo(exporter, method);
+}
+
+void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
+  const unsigned int unit = std::get<props::Unit>(zoneout_lstmcell_props).get();
+  const float hidden_state_zoneout_rate =
+    std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props);
+  const float cell_state_zoneout_rate =
+    std::get<CellStateZoneOutRate>(zoneout_lstmcell_props);
+  const bool test = std::get<Test>(zoneout_lstmcell_props);
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(zoneout_lstmcell_props);
+  const unsigned int timestep =
+    std::get<props::Timestep>(zoneout_lstmcell_props);
+
+  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
+  const TensorDim &input_dim = input.getDim();
+  const unsigned int batch_size = input_dim.batch();
+
+  Tensor &hidden_state =
+    context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state]);
+  hidden_state.reshape({max_timestep, 1, batch_size, unit});
+  Tensor prev_hidden_state;
+  if (!timestep) {
+    prev_hidden_state = Tensor(batch_size, 1, 1, unit);
+    prev_hidden_state.setZero();
+  } else {
+    prev_hidden_state = hidden_state.getBatchSlice(timestep - 1, 1);
+    prev_hidden_state.reshape({batch_size, 1, 1, unit});
+  }
+  Tensor next_hidden_state = hidden_state.getBatchSlice(timestep, 1);
+  next_hidden_state.reshape({batch_size, 1, 1, unit});
+
+  Tensor &cell_state = context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state]);
+  cell_state.reshape({max_timestep, 1, batch_size, unit});
+  Tensor prev_cell_state;
+  if (!timestep) {
+    prev_cell_state = Tensor(batch_size, 1, 1, unit);
+    prev_cell_state.setZero();
+  } else {
+    prev_cell_state = cell_state.getBatchSlice(timestep - 1, 1);
+    prev_cell_state.reshape({batch_size, 1, 1, unit});
+  }
+  Tensor next_cell_state = cell_state.getBatchSlice(timestep, 1);
+  next_cell_state.reshape({batch_size, 1, 1, unit});
+
+  if (!timestep) {
+    hidden_state.setZero();
+    cell_state.setZero();
+  }
+
+  init_lstm_context::fillWeights(weights, context, training, max_timestep,
+                                 timestep, test);
+  init_lstm_context::fillInputs(inputs, context, training, getInputIdx(wt_idx),
+                                max_timestep, timestep);
+  init_lstm_context::fillOutputs(outputs, context, training,
+                                 getOutputIdx(wt_idx), max_timestep, timestep);
+  init_lstm_context::fillTensors(tensors, context, training,
+                                 getTensorIdx(wt_idx), max_timestep, timestep);
+  RunLayerContext core_context(context.getName(), context.getTrainable(),
+                               context.getLoss(), context.executeInPlace(),
+                               init_lstm_context::getWeights(weights),
+                               init_lstm_context::getInputs(inputs),
+                               init_lstm_context::getOutputs(outputs),
+                               init_lstm_context::getTensors(tensors));
+  lstmcellcorelayer.forwarding(core_context, training);
+
+  if (hidden_state_zoneout_rate > epsilon) {
+    if (training) {
+      Tensor &hidden_state_zoneout_mask =
+        test ? context.getWeight(
+                 wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
+             : context.getTensor(
+                 wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
+      hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+      Tensor next_hidden_state_zoneout_mask =
+        hidden_state_zoneout_mask.getBatchSlice(timestep, 1);
+      next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+      Tensor prev_hidden_state_zoneout_mask;
+      if (!test) {
+        prev_hidden_state_zoneout_mask =
+          next_hidden_state_zoneout_mask.zoneout_mask(
+            hidden_state_zoneout_rate);
+      } else {
+        next_hidden_state_zoneout_mask.multiply(-1.0f,
+                                                prev_hidden_state_zoneout_mask);
+        prev_hidden_state_zoneout_mask.add_i(1.0f);
+      }
+
+      Tensor &hidden_state_origin = context.getTensor(hidden_state_origin_idx);
+      hidden_state_origin.reshape({max_timestep, 1, batch_size, unit});
+      Tensor next_hidden_state_origin =
+        hidden_state_origin.getBatchSlice(timestep, 1);
+      next_hidden_state_origin.reshape({batch_size, 1, 1, unit});
+
+      next_hidden_state_origin.multiply(next_hidden_state_zoneout_mask,
+                                        next_hidden_state);
+      prev_hidden_state.multiply(prev_hidden_state_zoneout_mask,
+                                 next_hidden_state, 1.0f);
+    }
+    // Todo: zoneout at inference
+  }
+  if (cell_state_zoneout_rate > epsilon) {
+    if (training) {
+      Tensor &cell_state_zoneout_mask =
+        test ? context.getWeight(
+                 wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
+             : context.getTensor(
+                 wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
+      cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+      Tensor next_cell_state_zoneout_mask =
+        cell_state_zoneout_mask.getBatchSlice(timestep, 1);
+      next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+      Tensor prev_cell_state_zoneout_mask;
+      if (!test) {
+        prev_cell_state_zoneout_mask =
+          next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
+      } else {
+        next_cell_state_zoneout_mask.multiply(-1.0f,
+                                              prev_cell_state_zoneout_mask);
+        prev_cell_state_zoneout_mask.add_i(1.0f);
+      }
+
+      Tensor &cell_state_origin = context.getTensor(cell_state_origin_idx);
+      cell_state_origin.reshape({max_timestep, 1, batch_size, unit});
+      Tensor next_cell_state_origin =
+        cell_state_origin.getBatchSlice(timestep, 1);
+      next_cell_state_origin.reshape({batch_size, 1, 1, unit});
+
+      next_cell_state_origin.multiply(next_cell_state_zoneout_mask,
+                                      next_cell_state);
+      prev_cell_state.multiply(prev_cell_state_zoneout_mask, next_cell_state,
+                               1.0f);
+    }
+    // Todo: zoneout at inference
+  }
+
+  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
+  output.copyData(next_hidden_state);
+}
+
+void ZoneoutLSTMCellLayer::calcDerivative(RunLayerContext &context) {
+  const bool test = std::get<Test>(zoneout_lstmcell_props);
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(zoneout_lstmcell_props);
+  const unsigned int timestep =
+    std::get<props::Timestep>(zoneout_lstmcell_props);
+
+  init_lstm_context::fillWeights(weights, context, true, max_timestep, timestep,
+                                 test);
+  init_lstm_context::fillInputs(inputs, context, true, getInputIdx(wt_idx),
+                                max_timestep, timestep);
+  init_lstm_context::fillOutputs(outputs, context, true, getOutputIdx(wt_idx),
+                                 max_timestep, timestep);
+  init_lstm_context::fillTensors(tensors, context, true, getTensorIdx(wt_idx),
+                                 max_timestep, timestep);
+  RunLayerContext core_context(context.getName(), context.getTrainable(),
+                               context.getLoss(), context.executeInPlace(),
+                               init_lstm_context::getWeights(weights),
+                               init_lstm_context::getInputs(inputs),
+                               init_lstm_context::getOutputs(outputs),
+                               init_lstm_context::getTensors(tensors));
+  lstmcellcorelayer.calcDerivative(core_context);
+}
+
+void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
+  const unsigned int unit = std::get<props::Unit>(zoneout_lstmcell_props).get();
+  const float hidden_state_zoneout_rate =
+    std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props);
+  const float cell_state_zoneout_rate =
+    std::get<CellStateZoneOutRate>(zoneout_lstmcell_props);
+  const bool test = std::get<Test>(zoneout_lstmcell_props);
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(zoneout_lstmcell_props);
+  const unsigned int timestep =
+    std::get<props::Timestep>(zoneout_lstmcell_props);
+
+  unsigned int batch_size = context.getInput(SINGLE_INOUT_IDX).getDim().batch();
+
+  const Tensor &incoming_derivative =
+    context.getIncomingDerivative(SINGLE_INOUT_IDX);
+
+  Tensor &hidden_state_derivative =
+    context.getTensorGrad(wt_idx[ZoneoutLSTMParams::hidden_state]);
+  hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit});
+  Tensor next_hidden_state_derivative =
+    hidden_state_derivative.getBatchSlice(timestep, 1);
+  next_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
+
+  Tensor &cell_state_derivative =
+    context.getTensorGrad(wt_idx[ZoneoutLSTMParams::cell_state]);
+  cell_state_derivative.reshape({max_timestep, 1, batch_size, unit});
+  Tensor next_cell_state_derivative =
+    cell_state_derivative.getBatchSlice(timestep, 1);
+  next_cell_state_derivative.reshape({batch_size, 1, 1, unit});
+
+  if (timestep + 1 == max_timestep) {
+    Tensor &djdweight_ih =
+      context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_ih]);
+    Tensor &djdweight_hh =
+      context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_hh]);
+    Tensor &djdbias_ih =
+      context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_ih]);
+    djdweight_ih.setZero();
+    djdweight_hh.setZero();
+    djdbias_ih.setZero();
+
+    hidden_state_derivative.setZero();
+    cell_state_derivative.setZero();
+  }
+
+  next_hidden_state_derivative.add_i(incoming_derivative);
+
+  Tensor prev_hidden_state_derivative;
+  Tensor prev_cell_state_derivative;
+  Tensor prev_hidden_state_derivative_residual;
+  Tensor prev_cell_state_derivative_residual;
+  if (hidden_state_zoneout_rate > epsilon) {
+    Tensor &hidden_state_zoneout_mask =
+      test ? context.getWeight(
+               wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
+           : context.getTensor(
+               wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
+    hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+    Tensor next_hidden_state_zoneout_mask =
+      hidden_state_zoneout_mask.getBatchSlice(timestep, 1);
+    next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+    Tensor prev_hidden_state_zoneout_mask;
+    if (!test) {
+      prev_hidden_state_zoneout_mask =
+        next_hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
+    } else {
+      next_hidden_state_zoneout_mask.multiply(-1.0f,
+                                              prev_hidden_state_zoneout_mask);
+      prev_hidden_state_zoneout_mask.add_i(1.0f);
+    }
+
+    if (timestep) {
+      prev_hidden_state_derivative =
+        hidden_state_derivative.getBatchSlice(timestep - 1, 1);
+      prev_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
+      next_hidden_state_derivative.multiply(
+        prev_hidden_state_zoneout_mask, prev_hidden_state_derivative_residual);
+    }
+
+    Tensor &hidden_state_origin_derivative =
+      context.getTensorGrad(hidden_state_origin_idx);
+    hidden_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit});
+    Tensor next_hidden_state_origin_derivative =
+      hidden_state_origin_derivative.getBatchSlice(timestep, 1);
+    next_hidden_state_origin_derivative.reshape({batch_size, 1, 1, unit});
+
+    next_hidden_state_derivative.multiply(next_hidden_state_zoneout_mask,
+                                          next_hidden_state_origin_derivative);
+  }
+  if (cell_state_zoneout_rate > epsilon) {
+    Tensor &cell_state_zoneout_mask =
+      test
+        ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
+        : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
+    cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+    Tensor next_cell_state_zoneout_mask =
+      cell_state_zoneout_mask.getBatchSlice(timestep, 1);
+    next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+    Tensor prev_cell_state_zoneout_mask;
+    if (!test) {
+      prev_cell_state_zoneout_mask =
+        next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
+    } else {
+      next_cell_state_zoneout_mask.multiply(-1.0f,
+                                            prev_cell_state_zoneout_mask);
+      prev_cell_state_zoneout_mask.add_i(1.0f);
+    }
+
+    if (timestep) {
+      prev_cell_state_derivative =
+        cell_state_derivative.getBatchSlice(timestep - 1, 1);
+      prev_cell_state_derivative.reshape({batch_size, 1, 1, unit});
+      next_cell_state_derivative.multiply(prev_cell_state_zoneout_mask,
+                                          prev_cell_state_derivative_residual);
+    }
+
+    Tensor &cell_state_origin_derivative =
+      context.getTensorGrad(cell_state_origin_idx);
+    cell_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit});
+    Tensor next_cell_state_origin_derivative =
+      cell_state_origin_derivative.getBatchSlice(timestep, 1);
+    next_cell_state_origin_derivative.reshape({batch_size, 1, 1, unit});
+
+    next_cell_state_derivative.multiply(next_cell_state_zoneout_mask,
+                                        next_cell_state_origin_derivative);
+  }
+
+  init_lstm_context::fillWeights(weights, context, true, max_timestep, timestep,
+                                 test);
+  init_lstm_context::fillInputs(inputs, context, true, getInputIdx(wt_idx),
+                                max_timestep, timestep);
+  init_lstm_context::fillOutputs(outputs, context, true, getOutputIdx(wt_idx),
+                                 max_timestep, timestep);
+  init_lstm_context::fillTensors(tensors, context, true, getTensorIdx(wt_idx),
+                                 max_timestep, timestep);
+  RunLayerContext core_context(context.getName(), context.getTrainable(),
+                               context.getLoss(), context.executeInPlace(),
+                               init_lstm_context::getWeights(weights),
+                               init_lstm_context::getInputs(inputs),
+                               init_lstm_context::getOutputs(outputs),
+                               init_lstm_context::getTensors(tensors));
+  lstmcellcorelayer.calcGradient(core_context);
+
+  if (timestep) {
+    if (hidden_state_zoneout_rate > epsilon) {
+      prev_hidden_state_derivative.add_i(prev_hidden_state_derivative_residual);
+    }
+    if (cell_state_zoneout_rate > epsilon) {
+      prev_cell_state_derivative.add_i(prev_cell_state_derivative_residual);
+    }
+  }
+}
+
+void ZoneoutLSTMCellLayer::setBatch(RunLayerContext &context,
+                                    unsigned int batch) {
+  const float hidden_state_zoneout_rate =
+    std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props);
+  const float cell_state_zoneout_rate =
+    std::get<CellStateZoneOutRate>(zoneout_lstmcell_props);
+  const bool test = std::get<Test>(zoneout_lstmcell_props);
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(zoneout_lstmcell_props);
+  context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state],
+                       max_timestep * batch);
+  context.updateTensor(wt_idx[ZoneoutLSTMParams::cell_state],
+                       max_timestep * batch);
+  context.updateTensor(hidden_state_origin_idx, max_timestep * batch);
+  context.updateTensor(cell_state_origin_idx, max_timestep * batch);
+  context.updateTensor(wt_idx[ZoneoutLSTMParams::ifgo], max_timestep * batch);
+
+  if (hidden_state_zoneout_rate > epsilon && !test) {
+    context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask],
+                         max_timestep * batch);
+  }
+  if (cell_state_zoneout_rate > epsilon && !test) {
+    context.updateTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask],
+                         max_timestep * batch);
+  }
+}
+
+} // namespace nntrainer
diff --git a/nntrainer/layers/zoneout_lstmcell.h b/nntrainer/layers/zoneout_lstmcell.h
new file mode 100644 (file)
index 0000000..19f9c3b
--- /dev/null
@@ -0,0 +1,201 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
+ *
+ * @file   zoneout_lstmcell.h
+ * @date   30 November 2021
+ * @brief  This is ZoneoutLSTMCell Layer Class of Neural Network
+ * @see           https://github.com/nnstreamer/nntrainer
+ * @author hyeonseok lee <hs89.lee@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
+
+#ifndef __ZONEOUTLSTMCELL_H__
+#define __ZONEOUTLSTMCELL_H__
+#ifdef __cplusplus
+
+#include <acti_func.h>
+#include <common_properties.h>
+#include <layer_impl.h>
+#include <lstmcell_core.h>
+
+namespace nntrainer {
+
+/**
+ * @class   ZoneoutLSTMCellLayer
+ * @brief   LSTM cell layer with zoneout regularization applied to the hidden
+ * and cell states (refer: https://arxiv.org/pdf/1606.01305.pdf)
+ */
+class ZoneoutLSTMCellLayer : public LayerImpl {
+public:
+  /**
+   * @brief HiddenStateZoneOutRate property, this defines zone out rate for
+   * hidden state
+   *
+   */
+  class HiddenStateZoneOutRate : public nntrainer::Property<float> {
+
+  public:
+    /**
+     * @brief Construct a new HiddenStateZoneOutRate object with a default value
+     * 0.0
+     *
+     */
+    HiddenStateZoneOutRate(float value = 0.0) :
+      nntrainer::Property<float>(value) {}
+    static constexpr const char *key =
+      "hidden_state_zoneout_rate";   /**< unique key to access */
+    using prop_tag = float_prop_tag; /**< property type */
+
+    /**
+     * @brief HiddenStateZoneOutRate validator
+     *
+     * @param value float to validate
+     * @retval true if it is equal or greater than 0.0 and equal or smaller than
+     * to 1.0
+     * @retval false if it is smaller than 0.0 or greater than 1.0
+     */
+    bool isValid(const float &value) const override;
+  };
+
+  /**
+   * @brief CellStateZoneOutRate property, this defines zone out rate for cell
+   * state
+   *
+   */
+  class CellStateZoneOutRate : public nntrainer::Property<float> {
+
+  public:
+    /**
+     * @brief Construct a new CellStateZoneOutRate object with a default value
+     * 0.0
+     *
+     */
+    CellStateZoneOutRate(float value = 0.0) :
+      nntrainer::Property<float>(value) {}
+    static constexpr const char *key =
+      "cell_state_zoneout_rate";     /**< unique key to access */
+    using prop_tag = float_prop_tag; /**< property type */
+
+    /**
+     * @brief CellStateZoneOutRate validator
+     *
+     * @param value float to validate
+     * @retval true if it is equal or greater than 0.0 and equal or smaller than
+     * to 1.0
+     * @retval false if it is smaller than 0.0 or greater than 1.0
+     */
+    bool isValid(const float &value) const override;
+  };
+
+  /**
+   * @brief Test property, this property is set to true when test the zoneout
+   * lstmcell in unittest
+   *
+   */
+  class Test : public nntrainer::Property<bool> {
+
+  public:
+    /**
+     * @brief Construct a new Test object with a default value false
+     *
+     */
+    Test(bool value = false) : nntrainer::Property<bool>(value) {}
+    static constexpr const char *key = "test"; /**< unique key to access */
+    using prop_tag = bool_prop_tag;            /**< property type */
+  };
+
+  /**
+   * @brief     Constructor of ZoneoutLSTMCellLayer
+   */
+  ZoneoutLSTMCellLayer();
+
+  /**
+   * @brief     Destructor of ZoneoutLSTMCellLayer
+   */
+  ~ZoneoutLSTMCellLayer() = default;
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(InitLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+   */
+  void forwarding(RunLayerContext &context, bool training) override;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(RunLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::calcGradient(RunLayerContext &context)
+   */
+  void calcGradient(RunLayerContext &context) override;
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
+   */
+  void exportTo(Exporter &exporter, const ExportMethods &method) const override;
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const override {
+    return ZoneoutLSTMCellLayer::type;
+  };
+
+  /**
+   * @copydoc Layer::supportBackwarding()
+   */
+  bool supportBackwarding() const override { return true; }
+
+  /**
+   * @copydoc Layer::setProperty(const PropertyType type, const std::string
+   * &value)
+   */
+  void setProperty(const std::vector<std::string> &values) override;
+
+  /**
+   * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch)
+   */
+  void setBatch(RunLayerContext &context, unsigned int batch) override;
+
+  inline static const std::string type = "zoneout_lstmcell";
+
+private:
+  static constexpr unsigned int NUM_GATE = 4;
+
+  LSTMCellCoreLayer lstmcellcorelayer;
+
+  /**
+   * Unit: number of output neurons
+   * HiddenStateZoneOutRate: zoneout rate for hidden_state
+   * CellStateZoneOutRate: zoneout rate for cell_state
+   * Test: property for test mode
+   * MaxTimestep: maximum timestep for zoneout lstmcell
+   * TimeStep: timestep for which lstm should operate
+   *
+   * */
+  std::tuple<props::Unit, HiddenStateZoneOutRate, CellStateZoneOutRate, Test,
+             props::MaxTimestep, props::Timestep>
+    zoneout_lstmcell_props;
+  std::array<unsigned int, 8> wt_idx; /**< indices of the weights/tensors */
+
+  /**
+   * @brief     small value used to guard floating point comparisons, e.g. to
+   * decide whether a zoneout rate is effectively zero
+   */
+  float epsilon;
+
+  // These weights, inputs, outputs, tensors are all for the lstm_core
+  // Todo: remove this
+  std::vector<Weight> weights;
+  std::vector<Var_Grad> inputs;
+  std::vector<Var_Grad> outputs;
+  std::vector<Var_Grad> tensors;
+};
+} // namespace nntrainer
+
+#endif /* __cplusplus */
+#endif /* __ZONEOUTLSTMCELL_H__ */
index d0cf702..e13b4e8 100644 (file)
@@ -120,8 +120,7 @@ public:
   SrcSharedTensor() : src(nullptr), off(0) {}
 
   SrcSharedTensor(const Tensor *tensor, unsigned int offset) :
-    src(tensor),
-    off(offset) {}
+    src(tensor), off(offset) {}
 
   /**
    * @brief   Get the allocated src tensor
@@ -252,6 +251,11 @@ void Tensor::setRandUniform(float min, float max) {
     std::uniform_real_distribution<float>(min, max));
 }
 
+// Fill the tensor with Bernoulli samples: each element becomes 1.0f with the
+// given probability and 0.0f otherwise (bool samples converted to float).
+void Tensor::setRandBernoulli(float probability) {
+  setDist<std::bernoulli_distribution>(
+    std::bernoulli_distribution(probability));
+}
+
 void Tensor::initialize() {
   if (empty() || !isAllocated())
     return;
@@ -1248,6 +1252,26 @@ void Tensor::filter_mask(const Tensor &mask_len, bool reverse) {
   }
 }
 
+// Convenience overload: allocates the opposite mask with this tensor's
+// dimensions, fills both masks in place, and returns the opposite mask.
+Tensor Tensor::zoneout_mask(float zoneout) {
+  Tensor ret(getDim());
+  zoneout_mask(ret, zoneout);
+  return ret;
+}
+
+// Fill @a opposite with a Bernoulli(zoneout) 0/1 mask and set this tensor to
+// its elementwise complement, i.e. this tensor's elements are 1.0f with
+// probability (1 - zoneout).
+void Tensor::zoneout_mask(Tensor &opposite, float zoneout) {
+  opposite.setRandBernoulli(zoneout);
+  float *data = getData();
+  float *opposite_data = opposite.getData();
+
+  // NOTE(review): assumes opposite has at least size() elements; no dimension
+  // check is performed here — confirm all callers pass a matching tensor.
+  for (unsigned int i = 0; i < size(); ++i) {
+    // Bernoulli samples are exactly 0.0f or 1.0f, so comparing against
+    // epsilon simply distinguishes the two values robustly.
+    if (opposite_data[i] > epsilon) {
+      data[i] = 0.0f;
+    } else {
+      data[i] = 1.0f;
+    }
+  }
+}
+
 int Tensor::apply_i(std::function<float(float)> f) {
   Tensor result = *this;
   apply(f, result);
index 499ef6b..0cf1c4e 100644 (file)
@@ -705,6 +705,26 @@ public:
   void filter_mask(const Tensor &mask_len, bool reverse = false);
 
   /**
+   * @brief Calculate 2 Zone Out Mask
+   * @details Calculate zone out mask according to the bernoulli distribution.
+   * Zone out mask with rate @a zoneout for inplace and the other zone out mask
+   * with rate @a (1-zoneout).
+   * @param zoneout zone out rate
+   * @retval Tensor zone out mask for opposite tensor
+   */
+  Tensor zoneout_mask(float zoneout);
+
+  /**
+   * @brief Calculate 2 Zone Out Mask
+   * @details Calculate zone out mask according to the bernoulli distribution.
+   * Zone out mask with rate @a zoneout for inplace and the other zone out mask
+   * with rate @a (1-zoneout).
+   * @param opposite opposite zone out mask
+   * @param zoneout zone out rate
+   */
+  void zoneout_mask(Tensor &opposite, float zoneout);
+
+  /**
    * @brief     sum all the Tensor elements according to the batch
    * @retval    Calculated Tensor(batch, 1, 1, 1)
    */
@@ -974,6 +994,12 @@ public:
   void setRandUniform(float min = -0.05f, float max = 0.05f);
 
   /**
+   * @brief     Set the tensor with random bernoulli distribution
+   * @param[in] probability probability value for the distribution
+   */
+  void setRandBernoulli(float probability = 0.5f);
+
+  /**
    * @brief     Initialize the memory of the given tensor
    */
   void initialize();