[lstmcell] Refactoring the lstmcell
authorhyeonseok lee <hs89.lee@samsung.com>
Tue, 30 Nov 2021 11:55:54 +0000 (20:55 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 6 Dec 2021 12:35:08 +0000 (21:35 +0900)
 - Refactoring the lstmcell to lstm core layer.
   This lstm core layer will be used in zoneout lstmcell layer
 - lstm core layer is designed to have 3 inputs, 2 outputs
   like other framework.

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
jni/Android.mk
nntrainer/compiler/recurrent_realizer.cpp
nntrainer/layers/lstmcell.cpp
nntrainer/layers/lstmcell.h
nntrainer/layers/lstmcell_core.cpp [new file with mode: 0644]
nntrainer/layers/lstmcell_core.h [new file with mode: 0644]
nntrainer/layers/meson.build

index 9db746c..997bee3 100644 (file)
@@ -174,6 +174,7 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/rnncell.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/lstm.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/lstmcell.cpp \
+                  $(NNTRAINER_ROOT)/nntrainer/layers/lstmcell_core.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/gru.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/grucell.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/time_dist.cpp \
index 02df5ab..21c1b9e 100644 (file)
@@ -18,6 +18,7 @@
 #include <layer_node.h>
 #include <lstm.h>
 #include <lstmcell.h>
+#include <lstmcell_core.h>
 #include <nntrainer_error.h>
 #include <node_exporter.h>
 #include <remap_realizer.h>
@@ -132,6 +133,7 @@ static void propagateTimestep(LayerNode *node, unsigned int time_step,
     return node->getType() == RNNCellLayer::type ||
            node->getType() == LSTMLayer::type ||
            node->getType() == LSTMCellLayer::type ||
+           node->getType() == LSTMCellCoreLayer::type ||
            node->getType() == GRUCellLayer::type;
   };
 
index 4daf12e..c9905a1 100644 (file)
  *
  */
 
-#include <cmath>
 #include <layer_context.h>
 #include <lstmcell.h>
 #include <nntrainer_error.h>
 #include <nntrainer_log.h>
 #include <node_exporter.h>
-#include <util_func.h>
 
 namespace nntrainer {
 
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
-enum LSTMParams {
-  weight_xh,
+enum LSTMCellParams {
+  weight_ih,
   weight_hh,
-  bias_h,
+  bias_ih,
   hidden_state,
-  mem_cell,
-  fgio,
+  cell_state,
+  ifgo,
   dropout_mask
 };
 
+const std::vector<unsigned int>
+getInOutIdx(std::array<unsigned int, 7> &wt_idx) {
+  std::vector<unsigned int> ret(3);
+  ret[0] = SINGLE_INOUT_IDX;
+  ret[1] = wt_idx[LSTMCellParams::hidden_state];
+  ret[2] = wt_idx[LSTMCellParams::cell_state];
+  return ret;
+}
+
+const std::vector<unsigned int>
+getTensorIdx(std::array<unsigned int, 7> &wt_idx) {
+  std::vector<unsigned int> ret(1);
+  ret[0] = wt_idx[LSTMCellParams::ifgo];
+  return ret;
+}
+
 LSTMCellLayer::LSTMCellLayer() :
   LayerImpl(),
-  lstm_props(props::Unit(), props::HiddenStateActivation(),
-             props::RecurrentActivation(), props::DropOutRate(),
-             props::MaxTimestep(), props::Timestep()),
+  lstmcell_props(props::Unit(), props::DropOutRate(), props::MaxTimestep(),
+                 props::Timestep()),
   wt_idx({0}),
-  acti_func(ActivationType::ACT_NONE, true),
-  recurrent_acti_func(ActivationType::ACT_NONE, true),
   epsilon(1e-3) {}
 
-// - weight_xh ( input to hidden )
-//  : [1, 1, input_size, unit (hidden_size) x NUM_GATE] -> f, g, i, o
-// - weight_hh ( hidden to hidden )
-//  : [1, 1, unit (hidden_size) , unit (hidden_size) x NUM_GATE] -> f, g, i, o
-// - bias_h ( hidden bias )
-//  : [1, 1, 1, unit (hidden_size) x NUM_GATE] -> f, g, i, o
 void LSTMCellLayer::finalize(InitLayerContext &context) {
-  auto &weight_regularizer =
+  NNTR_THROW_IF(std::get<props::Unit>(lstmcell_props).empty(),
+                std::invalid_argument)
+    << "unit property missing for lstmcell layer";
+  const unsigned int unit = std::get<props::Unit>(lstmcell_props).get();
+  const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props);
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(lstmcell_props);
+
+#if !ENABLE_SHARING_WT_IDX
+  const Tensor::Initializer weight_initializer =
+    std::get<props::WeightInitializer>(*layer_impl_props);
+  const Tensor::Initializer bias_initializer =
+    std::get<props::BiasInitializer>(*layer_impl_props);
+  const nntrainer::WeightRegularizer weight_regularizer =
     std::get<props::WeightRegularizer>(*layer_impl_props);
-  auto &weight_regularizer_constant =
+  const float weight_regularizer_constant =
     std::get<props::WeightRegularizerConstant>(*layer_impl_props);
-  auto &weight_initializer =
-    std::get<props::WeightInitializer>(*layer_impl_props);
-  auto &bias_initializer = std::get<props::BiasInitializer>(*layer_impl_props);
-
-  NNTR_THROW_IF(std::get<props::Unit>(lstm_props).empty(),
-                std::invalid_argument)
-    << "unit property missing for lstm layer";
-  auto unit = std::get<props::Unit>(lstm_props).get();
-  auto &hidden_state_activation_type =
-    std::get<props::HiddenStateActivation>(lstm_props);
-  auto &recurrent_activation_type =
-    std::get<props::RecurrentActivation>(lstm_props);
-  float dropout_rate = std::get<props::DropOutRate>(lstm_props);
+#endif
 
   if (context.getNumInputs() != 1)
-    throw std::invalid_argument("LSTM layer takes only one input");
-  if (std::get<props::MaxTimestep>(lstm_props).empty())
+    throw std::invalid_argument("LSTMCell layer takes only one input");
+  if (std::get<props::MaxTimestep>(lstmcell_props).empty())
     throw std::invalid_argument(
-      "Number of unroll steps must be provided to LSTM cells");
-  if (std::get<props::Timestep>(lstm_props).empty())
+      "Number of unroll steps(max timestep) must be provided to LSTM cell");
+  if (std::get<props::Timestep>(lstmcell_props).empty())
     throw std::invalid_argument(
       "Current Timestep must be provided to LSTM cell");
 
-  // input_dim = [ batch, 1, 1, feature_size ]
-  TensorDim output_dim;
+  // input_dim = [ batch_size, 1, 1, feature_size ]
   const TensorDim &input_dim = context.getInputDimensions()[0];
-  if (input_dim.height() != 1 && input_dim.channel() != 1)
+  if (input_dim.height() != 1 || input_dim.channel() != 1)
     throw std::invalid_argument(
-      "Input must be single time dimension for LSTMCell");
-  // output_dim = [ batch, 1, 1, hidden_size (unit)]
-  output_dim = input_dim;
-  output_dim.width(unit);
-
-  if (dropout_rate > epsilon) {
-    wt_idx[LSTMParams::dropout_mask] = context.requestTensor(
-      output_dim, "dropout_mask", Tensor::Initializer::NONE, false,
-      TensorLifespan::ITERATION_LIFESPAN);
-  }
-
+      "Input must be single time dimension for LSTMCell (shape should be "
+      "[batch_size, 1, 1, feature_size])");
+  const unsigned int batch_size = input_dim.batch();
+#if !ENABLE_SHARING_WT_IDX
+  const unsigned int feature_size = input_dim.width();
+#endif
+
+  // output_dim = [ batch_size, 1, 1, unit ]
+  const TensorDim output_dim(batch_size, 1, 1, unit);
   context.setOutputDimensions({output_dim});
 
-  TensorDim bias_dim = TensorDim();
-  bias_dim.setTensorDim(3, unit * NUM_GATE);
-
-  TensorDim dim_xh = output_dim;
-  dim_xh.height(input_dim.width());
-  dim_xh.width(unit * NUM_GATE);
-  dim_xh.batch(1);
-
-  TensorDim dim_hh = output_dim;
-  dim_hh.height(unit);
-  dim_hh.width(unit * NUM_GATE);
-  dim_hh.batch(1);
-
-  // weight_initializer can be set seperately. weight_xh initializer,
+#if !ENABLE_SHARING_WT_IDX
+  // weight_initializer can be set seperately. weight_ih initializer,
   // weight_hh initializer kernel initializer & recurrent_initializer in keras
   // for now, it is set same way.
-  wt_idx[LSTMParams::weight_xh] =
-    context.requestWeight(dim_xh, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_xh", true);
-  wt_idx[LSTMParams::weight_hh] =
-    context.requestWeight(dim_hh, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_hh", true);
-  wt_idx[LSTMParams::bias_h] = context.requestWeight(
-    bias_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, "bias_h", true);
-
-  unsigned int max_timestep = std::get<props::MaxTimestep>(lstm_props);
-
-  TensorDim d = input_dim;
-  // d.height(d.batch());
-  d.height(1);
-  d.batch(max_timestep * d.batch());
-  d.width(unit);
 
-  /** hidden dim = [ UnrollLength, 1, Batch, Units ] */
-  wt_idx[LSTMParams::hidden_state] =
-    context.requestTensor(d, "hidden_state", Tensor::Initializer::NONE, true,
-                          TensorLifespan::ITERATION_LIFESPAN, false);
-  wt_idx[LSTMParams::mem_cell] =
-    context.requestTensor(d, "mem_cell", Tensor::Initializer::NONE, true,
-                          TensorLifespan::ITERATION_LIFESPAN, false);
+  // - weight_ih ( input to hidden )
+  //  : [1, 1, feature_size, NUM_GATE x unit] -> i, f, g, o
+  TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
+  wt_idx[LSTMCellParams::weight_ih] =
+    context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
+                          weight_regularizer_constant, "weight_ih", true);
+  // - weight_hh ( hidden to hidden )
+  //  : [1, 1, unit, NUM_GATE x unit] -> i, f, g, o
+  TensorDim weight_hh_dim({unit, NUM_GATE * unit});
+  wt_idx[LSTMCellParams::weight_hh] =
+    context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
+                          weight_regularizer_constant, "weight_hh", true);
+  // - bias_ih ( input bias )
+  //  : [1, 1, 1, NUM_GATE x unit] -> i, f, g, o
+  TensorDim bias_ih_dim({NUM_GATE * unit});
+  wt_idx[LSTMCellParams::bias_ih] =
+    context.requestWeight(bias_ih_dim, bias_initializer,
+                          WeightRegularizer::NONE, 1.0f, "bias_ih", true);
+#endif
+
+  // dropout_mask_dim = [ max_timestep * batch_size, 1, 1, unit ]
+  const TensorDim dropout_mask_dim(max_timestep * batch_size, 1, 1, unit);
+  if (dropout_rate > epsilon) {
+    wt_idx[LSTMCellParams::dropout_mask] = context.requestTensor(
+      dropout_mask_dim, "dropout_mask", Tensor::Initializer::NONE, false,
+      TensorLifespan::ITERATION_LIFESPAN);
+  }
 
   /**
+   * TODO: hidden_state is only used from the previous timestep. Once it is
+   * supported as input, no need to cache the hidden_state itself
+   */
+  /** hidden_state_dim = [ max_timestep * batch_size, 1, 1, unit ] */
+  const TensorDim hidden_state_dim(max_timestep * batch_size, 1, 1, unit);
+  wt_idx[LSTMCellParams::hidden_state] = context.requestTensor(
+    hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true,
+    TensorLifespan::ITERATION_LIFESPAN, false);
+  /** cell_state_dim = [ max_timestep * batch_size, 1, 1, unit ] */
+  const TensorDim cell_state_dim(max_timestep * batch_size, 1, 1, unit);
+  wt_idx[LSTMCellParams::cell_state] = context.requestTensor(
+    cell_state_dim, "cell_state", Tensor::Initializer::NONE, true,
+    TensorLifespan::ITERATION_LIFESPAN, false);
+#if !ENABLE_SHARING_WT_IDX
+  /**
    * TODO: make this independent of time dimension once recurrent realizer
    * supports requesting tensors which are not always shared
-   *
-   * TODO: reorder to ifgo for better performance. This will require change in
-   * stored weights in the test
    */
-  d.width(unit * NUM_GATE);
-  wt_idx[LSTMParams::fgio] =
-    context.requestTensor(d, "fgio", Tensor::Initializer::NONE, true,
-                          TensorLifespan::ITERATION_LIFESPAN, false);
+  /** ifgo_dim = [ max_timestep * batch_size, 1, 1, NUM_GATE * unit ] */
+  const TensorDim ifgo_dim(max_timestep * batch_size, 1, 1, NUM_GATE * unit);
+  wt_idx[LSTMCellParams::ifgo] =
+    context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true,
+                          TensorLifespan::ITERATION_LIFESPAN);
+#endif
+
+  TensorDim hidden_state_t_dim({batch_size, 1, 1, unit});
+  TensorDim cell_state_t_dim({batch_size, 1, 1, unit});
+  InitLayerContext core_context(
+    {input_dim, hidden_state_t_dim, cell_state_t_dim}, 3,
+    context.executeInPlace(), context.getName());
+  lstmcellcorelayer.finalize(core_context);
+  init_lstm_context::fillLayerInitContext(context, core_context);
+}
 
-  if (hidden_state_activation_type.get() == ActivationType::ACT_NONE) {
-    hidden_state_activation_type.set(ActivationType::ACT_TANH);
+void LSTMCellLayer::setProperty(const std::vector<std::string> &values) {
+  std::vector<std::string> remain_props =
+    loadProperties(values, lstmcell_props);
+
+  // Note: In current implementation the lstmcellcorelayer also has
+  // a properties related to weight. But it is not exported or used anywhere.
+  lstmcellcorelayer.setProperty(remain_props);
+  if (!std::get<props::Unit>(lstmcell_props).empty()) {
+    lstmcellcorelayer.setProperty(
+      {"unit=" + to_string(std::get<props::Unit>(lstmcell_props))});
   }
-  acti_func.setActiFunc(hidden_state_activation_type.get());
 
-  if (recurrent_activation_type.get() == ActivationType::ACT_NONE) {
-    recurrent_activation_type.set(ActivationType::ACT_SIGMOID);
-  }
-  recurrent_acti_func.setActiFunc(recurrent_activation_type.get());
-}
+#if !ENABLE_SHARING_WT_IDX
+  // To remove lstmcell core layer's properties
+  std::tuple<props::HiddenStateActivation, props::RecurrentActivation>
+    lstmcell_core_props;
+  std::vector<std::string> impl_props =
+    loadProperties(remain_props, lstmcell_core_props);
 
-void LSTMCellLayer::setProperty(const std::vector<std::string> &values) {
-  auto remain_props = loadProperties(values, lstm_props);
-  LayerImpl::setProperty(remain_props);
+  LayerImpl::setProperty(impl_props);
+#endif
 }
 
 void LSTMCellLayer::exportTo(Exporter &exporter,
                              const ExportMethods &method) const {
+#if !ENABLE_SHARING_WT_IDX
   LayerImpl::exportTo(exporter, method);
-  exporter.saveResult(lstm_props, method, this);
+#endif
+  exporter.saveResult(
+    std::forward_as_tuple(std::get<props::DropOutRate>(lstmcell_props),
+                          std::get<props::MaxTimestep>(lstmcell_props),
+                          std::get<props::Timestep>(lstmcell_props)),
+    method, this);
+  lstmcellcorelayer.exportTo(exporter, method);
 }
 
 void LSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
-  auto unit = std::get<props::Unit>(lstm_props).get();
-  float dropout_rate = std::get<props::DropOutRate>(lstm_props);
-
-  Tensor &weight_xh = context.getWeight(wt_idx[LSTMParams::weight_xh]);
-  Tensor &weight_hh = context.getWeight(wt_idx[LSTMParams::weight_hh]);
-  Tensor &bias_h = context.getWeight(wt_idx[LSTMParams::bias_h]);
-
-  Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
-  Tensor &hidden_ = context.getTensor(wt_idx[LSTMParams::hidden_state]);
-  Tensor &cell_ = context.getTensor(wt_idx[LSTMParams::mem_cell]);
-  Tensor &fgio = context.getTensor(wt_idx[LSTMParams::fgio]);
-  const TensorDim &input_dim = input_.getDim();
-  unsigned int batch = input_dim.batch();
-
-  unsigned int start_timestep = std::get<props::Timestep>(lstm_props);
-
-  if (start_timestep == 0) {
-    hidden_.setZero();
-    cell_.setZero();
-  }
-
-  unsigned int max_timestep = std::get<props::MaxTimestep>(lstm_props);
-  hidden_.reshape({max_timestep, 1, batch, hidden_.width()});
-  cell_.reshape({max_timestep, 1, batch, cell_.width()});
-  fgio.reshape({max_timestep, 1, batch, fgio.width()});
-
-  /**
-   * @note when the recurrent realization happens, different instances of lstm
-   * will share the weights, hidden state, cell and fgio memory. However, they
-   * do not share the input, output and derivatives memory. The input/output
-   * will be contain a single timestep data only.
-   */
-  Tensor hs = hidden_.getBatchSlice(start_timestep, 1);
-  Tensor cs = cell_.getBatchSlice(start_timestep, 1);
-  Tensor fgio_t = fgio.getBatchSlice(start_timestep, 1);
-
-  input_.dot(weight_xh, fgio_t);
-
-  if (start_timestep > 0) {
-    Tensor hs_prev = hidden_.getBatchSlice(start_timestep - 1, 1);
-    hs_prev.dot(weight_hh, fgio_t, false, false, 1.0);
-  }
-
-  fgio_t.add_i(bias_h);
-  Tensor hif = fgio_t.getSharedDataTensor({batch, unit * 2}, 0, false);
-  Tensor hi = fgio_t.getSharedDataTensor({batch, unit}, 0, false);
-  Tensor hf = fgio_t.getSharedDataTensor({batch, unit}, unit, false);
-  Tensor hg = fgio_t.getSharedDataTensor({batch, unit}, unit * 2, false);
-  Tensor ho = fgio_t.getSharedDataTensor({batch, unit}, unit * 3, false);
-
-  recurrent_acti_func.run_fn(hif, hif);
-  recurrent_acti_func.run_fn(ho, ho);
-  acti_func.run_fn(hg, hg);
-
-  if (start_timestep > 0) {
-    Tensor cs_prev = cell_.getBatchSlice(start_timestep - 1, 1);
-    hf.multiply_strided(cs_prev, cs);
+  const unsigned int unit = std::get<props::Unit>(lstmcell_props).get();
+  const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props);
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(lstmcell_props);
+  const unsigned int timestep = std::get<props::Timestep>(lstmcell_props);
+
+  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
+  const TensorDim &input_dim = input.getDim();
+  const unsigned int batch_size = input_dim.batch();
+
+  Tensor &hidden_state =
+    context.getTensor(wt_idx[LSTMCellParams::hidden_state]);
+  hidden_state.reshape({max_timestep, 1, batch_size, unit});
+  Tensor next_hidden_state = hidden_state.getBatchSlice(timestep, 1);
+  next_hidden_state.reshape({batch_size, 1, 1, unit});
+
+  Tensor &cell_state = context.getTensor(wt_idx[LSTMCellParams::cell_state]);
+
+  if (!timestep) {
+    hidden_state.setZero();
+    cell_state.setZero();
   }
-  hg.multiply_strided(hi, cs, 1.0);
 
-  acti_func.run_fn(cs, hs);
-  hs.multiply_i_strided(ho);
+  init_lstm_context::fillWeights(weights, context, training, max_timestep,
+                                 timestep);
+  init_lstm_context::fillInputs(inputs, context, training, getInOutIdx(wt_idx),
+                                max_timestep, timestep);
+  init_lstm_context::fillOutputs(outputs, context, training,
+                                 getInOutIdx(wt_idx), max_timestep, timestep);
+  init_lstm_context::fillTensors(tensors, context, training,
+                                 getTensorIdx(wt_idx), max_timestep, timestep);
+  RunLayerContext core_context(context.getName(), context.getTrainable(),
+                               context.getLoss(), context.executeInPlace(),
+                               init_lstm_context::getWeights(weights),
+                               init_lstm_context::getInputs(inputs),
+                               init_lstm_context::getOutputs(outputs),
+                               init_lstm_context::getTensors(tensors));
+  lstmcellcorelayer.forwarding(core_context, training);
 
   if (dropout_rate > epsilon && training) {
-    Tensor &mask_ = context.getTensor(wt_idx[LSTMParams::dropout_mask]);
-    hs.dropout_mask(dropout_rate);
-    hs.multiply_i(mask_);
+    Tensor &dropout_mask =
+      context.getTensor(wt_idx[LSTMCellParams::dropout_mask]);
+    dropout_mask.reshape({max_timestep, 1, batch_size, unit});
+    Tensor dropout_mask_t = dropout_mask.getBatchSlice(timestep, 1);
+    dropout_mask_t.dropout_mask(dropout_rate);
+    next_hidden_state.multiply_i(dropout_mask_t);
   }
 
   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
-  std::copy(hs.getData(), hs.getData() + hs.size(), output.getData());
+  output.copyData(next_hidden_state);
 }
 
 void LSTMCellLayer::calcDerivative(RunLayerContext &context) {
-  Tensor &derivative_ = context.getTensorGrad(wt_idx[LSTMParams::fgio]);
-  Tensor &weight = context.getWeight(wt_idx[LSTMParams::weight_xh]);
-  Tensor &ret_ = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
-
-  unsigned int max_timestep = std::get<props::MaxTimestep>(lstm_props);
-  derivative_.reshape({max_timestep, 1, ret_.batch(), derivative_.width()});
-
-  /** get the timestep values */
-  unsigned int start_timestep = std::get<props::Timestep>(lstm_props);
-
-  Tensor deriv_t = derivative_.getBatchSlice(start_timestep, 1);
-  deriv_t.dot(weight, ret_, false, true);
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(lstmcell_props);
+  const unsigned int timestep = std::get<props::Timestep>(lstmcell_props);
+
+  init_lstm_context::fillWeights(weights, context, true, max_timestep,
+                                 timestep);
+  init_lstm_context::fillInputs(inputs, context, true, getInOutIdx(wt_idx),
+                                max_timestep, timestep);
+  init_lstm_context::fillOutputs(outputs, context, true, getInOutIdx(wt_idx),
+                                 max_timestep, timestep);
+  init_lstm_context::fillTensors(tensors, context, true, getTensorIdx(wt_idx),
+                                 max_timestep, timestep);
+  RunLayerContext core_context(context.getName(), context.getTrainable(),
+                               context.getLoss(), context.executeInPlace(),
+                               init_lstm_context::getWeights(weights),
+                               init_lstm_context::getInputs(inputs),
+                               init_lstm_context::getOutputs(outputs),
+                               init_lstm_context::getTensors(tensors));
+  lstmcellcorelayer.calcDerivative(core_context);
 }
 
 void LSTMCellLayer::calcGradient(RunLayerContext &context) {
-  auto unit = std::get<props::Unit>(lstm_props).get();
-  float dropout_rate = std::get<props::DropOutRate>(lstm_props);
-
-  Tensor &djdw_x = context.getWeightGrad(wt_idx[LSTMParams::weight_xh]);
-  Tensor &djdw_h = context.getWeightGrad(wt_idx[LSTMParams::weight_hh]);
-  Tensor &djdb_h = context.getWeightGrad(wt_idx[LSTMParams::bias_h]);
-  Tensor &weight_hh = context.getWeight(wt_idx[LSTMParams::weight_hh]);
-
-  Tensor &derivative_ = context.getTensorGrad(wt_idx[LSTMParams::hidden_state]);
-  /**
-   * TODO: hidden_ is only used from the previous timestep. Once it is supported
-   * as input, no need to cache the hidden_ itself
-   */
-  Tensor &hidden_ = context.getTensor(wt_idx[LSTMParams::hidden_state]);
-  Tensor &incoming_deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
-  Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
-  Tensor &m_cell_ = context.getTensor(wt_idx[LSTMParams::mem_cell]);
-  Tensor &dm_cell_ = context.getTensorGrad(wt_idx[LSTMParams::mem_cell]);
-  Tensor &fgio = context.getTensor(wt_idx[LSTMParams::fgio]);
-  Tensor &d_fgio = context.getTensorGrad(wt_idx[LSTMParams::fgio]);
-  const TensorDim &input_dim = input_.getDim();
-  unsigned int batch = input_dim.batch();
-
-  /** get the timestep values */
-  unsigned int max_timestep = std::get<props::MaxTimestep>(lstm_props);
-  unsigned int start_timestep = std::get<props::Timestep>(lstm_props);
-
-  derivative_.reshape({max_timestep, 1, batch, derivative_.width()});
-  hidden_.reshape({max_timestep, 1, batch, hidden_.width()});
-  m_cell_.reshape({max_timestep, 1, batch, m_cell_.width()});
-  dm_cell_.reshape({max_timestep, 1, batch, dm_cell_.width()});
-  fgio.reshape({max_timestep, 1, batch, fgio.width()});
-  d_fgio.reshape({max_timestep, 1, batch, d_fgio.width()});
-
-  if (start_timestep + 1 == max_timestep) {
-    djdw_x.setZero();
-    djdw_h.setZero();
-    djdb_h.setZero();
-  }
-
-  Tensor dh = derivative_.getBatchSlice(start_timestep, 1);
-  dh.reshape(incoming_deriv.getDim());
-  if (start_timestep + 1 == max_timestep) {
-    dh.copyData(incoming_deriv);
-  } else {
-    dh.add_i(incoming_deriv);
+  const unsigned int unit = std::get<props::Unit>(lstmcell_props).get();
+  const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props);
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(lstmcell_props);
+  const unsigned int timestep = std::get<props::Timestep>(lstmcell_props);
+
+  unsigned int batch_size = context.getInput(SINGLE_INOUT_IDX).getDim().batch();
+
+  const Tensor &incoming_derivative =
+    context.getIncomingDerivative(SINGLE_INOUT_IDX);
+
+  Tensor &hidden_state_derivative =
+    context.getTensorGrad(wt_idx[LSTMCellParams::hidden_state]);
+  hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit});
+  Tensor next_hidden_state_derivative =
+    hidden_state_derivative.getBatchSlice(timestep, 1);
+  next_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
+
+  Tensor &cell_state_derivative =
+    context.getTensorGrad(wt_idx[LSTMCellParams::cell_state]);
+  cell_state_derivative.reshape({max_timestep, 1, batch_size, unit});
+  Tensor next_cell_state_derivative =
+    cell_state_derivative.getBatchSlice(timestep, 1);
+  next_cell_state_derivative.reshape({batch_size, 1, 1, unit});
+
+  if (timestep + 1 == max_timestep) {
+    Tensor &djdweight_ih =
+      context.getWeightGrad(wt_idx[LSTMCellParams::weight_ih]);
+    Tensor &djdweight_hh =
+      context.getWeightGrad(wt_idx[LSTMCellParams::weight_hh]);
+    Tensor &djdbias_ih = context.getWeightGrad(wt_idx[LSTMCellParams::bias_ih]);
+    djdweight_ih.setZero();
+    djdweight_hh.setZero();
+    djdbias_ih.setZero();
+
+    next_hidden_state_derivative.setZero();
+    next_cell_state_derivative.setZero();
   }
-  dh = derivative_.getBatchSlice(start_timestep, 1);
 
   if (dropout_rate > epsilon) {
-    derivative_.multiply_i(context.getTensor(wt_idx[LSTMParams::dropout_mask]));
+    Tensor &dropout_mask =
+      context.getTensor(wt_idx[LSTMCellParams::dropout_mask]);
+    dropout_mask.reshape({max_timestep, 1, batch_size, unit});
+    Tensor dropout_mask_t = dropout_mask.getBatchSlice(timestep, 1);
+    next_hidden_state_derivative.multiply_i(dropout_mask_t);
   }
 
-  Tensor dc = dm_cell_.getBatchSlice(start_timestep, 1);
-  Tensor xs = input_;
-  Tensor hs_t = hidden_.getBatchSlice(start_timestep, 1);
-  Tensor cs = m_cell_.getBatchSlice(start_timestep, 1);
-
-  Tensor dfgio_t = d_fgio.getBatchSlice(start_timestep, 1);
-  Tensor fgio_t = fgio.getBatchSlice(start_timestep, 1);
-
-  Tensor dhif = dfgio_t.getSharedDataTensor({batch, unit * 2}, 0, false);
-  Tensor dhi = dfgio_t.getSharedDataTensor({batch, unit}, 0, false);
-  Tensor dhf = dfgio_t.getSharedDataTensor({batch, unit}, unit, false);
-  Tensor dhg = dfgio_t.getSharedDataTensor({batch, unit}, unit * 2, false);
-  Tensor dho = dfgio_t.getSharedDataTensor({batch, unit}, unit * 3, false);
-
-  Tensor hif = fgio_t.getSharedDataTensor({batch, unit * 2}, 0, false);
-  Tensor hi = fgio_t.getSharedDataTensor({batch, unit}, 0, false);
-  Tensor hf = fgio_t.getSharedDataTensor({batch, unit}, unit, false);
-  Tensor hg = fgio_t.getSharedDataTensor({batch, unit}, unit * 2, false);
-  Tensor ho = fgio_t.getSharedDataTensor({batch, unit}, unit * 3, false);
-
-  acti_func.run_fn(cs, cs);
-  cs.multiply_strided(dh, dho);
-
-  if (start_timestep + 1 == max_timestep) {
-    acti_func.run_prime_fn(cs, dc, dh);
-    dc.multiply_i_strided(ho);
-  } else {
-    /// @todo optimize this by updating run_prime_fn to accumulate or make
-    /// it inplace somehow
-    Tensor dc_temp(dc.getDim());
-    acti_func.run_prime_fn(cs, dc_temp, dh);
-    dc_temp.multiply_strided(ho, dc, 1.0);
-  }
-
-  if (start_timestep > 0) {
-    Tensor dc_nx = dm_cell_.getBatchSlice(start_timestep - 1, 1);
-    dc.multiply_strided(hf, dc_nx);
-    Tensor cs_prev = m_cell_.getBatchSlice(start_timestep - 1, 1);
-    dc.multiply_strided(cs_prev, dhf);
-  } else {
-    dhf.setZero();
-  }
-
-  dc.multiply_strided(hg, dhi);
-  dc.multiply_strided(hi, dhg);
-
-  recurrent_acti_func.run_prime_fn(ho, dho, dho);
-  recurrent_acti_func.run_prime_fn(hif, dhif, dhif);
-  acti_func.run_prime_fn(hg, dhg, dhg);
-  dfgio_t.sum(2, djdb_h, 1.0, 1.0);
-
-  xs.dot(dfgio_t, djdw_x, true, false, 1.0f);
-  if (start_timestep != 0) {
-    Tensor hs_prev = hidden_.getBatchSlice(start_timestep - 1, 1);
-    hs_prev.dot(dfgio_t, djdw_h, true, false, 1.0f);
-    Tensor dh_nx = derivative_.getBatchSlice(start_timestep - 1, 1);
-    dfgio_t.dot(weight_hh, dh_nx, false, true, 1.0f);
-  }
+  next_hidden_state_derivative.add_i(incoming_derivative);
+
+  init_lstm_context::fillWeights(weights, context, true, max_timestep,
+                                 timestep);
+  init_lstm_context::fillInputs(inputs, context, true, getInOutIdx(wt_idx),
+                                max_timestep, timestep);
+  init_lstm_context::fillOutputs(outputs, context, true, getInOutIdx(wt_idx),
+                                 max_timestep, timestep);
+  init_lstm_context::fillTensors(tensors, context, true, getTensorIdx(wt_idx),
+                                 max_timestep, timestep);
+  RunLayerContext core_context(context.getName(), context.getTrainable(),
+                               context.getLoss(), context.executeInPlace(),
+                               init_lstm_context::getWeights(weights),
+                               init_lstm_context::getInputs(inputs),
+                               init_lstm_context::getOutputs(outputs),
+                               init_lstm_context::getTensors(tensors));
+  lstmcellcorelayer.calcGradient(core_context);
 }
 
 void LSTMCellLayer::setBatch(RunLayerContext &context, unsigned int batch) {
-  unsigned int max_timestep = std::get<props::MaxTimestep>(lstm_props);
-  context.updateTensor(wt_idx[LSTMParams::hidden_state], batch * max_timestep);
-  context.updateTensor(wt_idx[LSTMParams::mem_cell], batch * max_timestep);
-  context.updateTensor(wt_idx[LSTMParams::fgio], batch * max_timestep);
-
-  const float dropout_rate = std::get<props::DropOutRate>(lstm_props);
+  const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props);
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(lstmcell_props);
+  context.updateTensor(wt_idx[LSTMCellParams::hidden_state],
+                       max_timestep * batch);
+  context.updateTensor(wt_idx[LSTMCellParams::cell_state],
+                       max_timestep * batch);
+  context.updateTensor(wt_idx[LSTMCellParams::ifgo], max_timestep * batch);
   if (dropout_rate > epsilon) {
-    context.updateTensor(wt_idx[LSTMParams::dropout_mask], batch);
+    context.updateTensor(wt_idx[LSTMCellParams::dropout_mask],
+                         max_timestep * batch);
   }
 }
 
index cc6dcf4..9fdd408 100644 (file)
@@ -18,6 +18,7 @@
 #include <acti_func.h>
 #include <common_properties.h>
 #include <layer_impl.h>
+#include <lstmcell_core.h>
 
 namespace nntrainer {
 
@@ -87,34 +88,31 @@ public:
 private:
   static constexpr unsigned int NUM_GATE = 4;
 
+  LSTMCellCoreLayer lstmcellcorelayer;
+
   /**
    * Unit: number of output neurons
-   * HiddenStateActivation: activation type for hidden state. default is tanh
-   * RecurrentActivation: activation type for recurrent. default is sigmoid
    * DropOutRate: dropout rate
+   * MaxTimestep: maximum timestep for lstmcell
    * TimeStep: timestep for which lstm should operate
    *
    * */
-  std::tuple<props::Unit, props::HiddenStateActivation,
-             props::RecurrentActivation, props::DropOutRate, props::MaxTimestep,
+  std::tuple<props::Unit, props::DropOutRate, props::MaxTimestep,
              props::Timestep>
-    lstm_props;
+    lstmcell_props;
   std::array<unsigned int, 7> wt_idx; /**< indices of the weights */
 
   /**
-   * @brief     activation function for h_t : default is tanh
-   */
-  ActiFunc acti_func;
-
-  /**
-   * @brief     activation function for recurrent : default is sigmoid
-   */
-  ActiFunc recurrent_acti_func;
-
-  /**
    * @brief     to protect overflow
    */
   float epsilon;
+
+  // These weights, inputs, outputs, tensors are all for the lstm_core
+  // Todo: remove this
+  std::vector<Weight> weights;
+  std::vector<Var_Grad> inputs;
+  std::vector<Var_Grad> outputs;
+  std::vector<Var_Grad> tensors;
 };
 } // namespace nntrainer
 
diff --git a/nntrainer/layers/lstmcell_core.cpp b/nntrainer/layers/lstmcell_core.cpp
new file mode 100644 (file)
index 0000000..3930cc6
--- /dev/null
@@ -0,0 +1,538 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
+ *
+ * @file   lstmcell_core.cpp
+ * @date   25 November 2021
+ * @brief  This is LSTMCellCore Layer Class of Neural Network
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author hyeonseok lee <hs89.lee@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
+
+#include <layer_context.h>
+#include <lstmcell_core.h>
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <node_exporter.h>
+
+// ENABLE_SHARING_WT_IDX implies does the wt_idx of lstm_core can be shared
+// with lstm_cell variant layer.
+// Todo: remove this if sharing wt_idx with other lstm variant is enabled
+#define ENABLE_SHARING_WT_IDX 0
+
+namespace nntrainer {
+
+namespace init_lstm_context {
+void fillLayerInitContext(InitLayerContext &context,
+                          const InitLayerContext &core_context) {
+  /** real set the input flags */
+  auto const &input_dims = context.getInputDimensions();
+  for (unsigned int idx = 0; idx < core_context.getNumInputs(); idx++) {
+    context.setDynDimFlagInputDimension(idx, input_dims[idx].getDynDimFlag());
+    context.setEffDimFlagInputDimension(idx, input_dims[idx].getEffDimFlag());
+  }
+
+  /** real request of tensors */
+  for (auto const &ts : core_context.getTensorsSpec())
+    context.requestTensor(ts);
+
+  /** real request of weights */
+  for (auto const &ws : core_context.getWeightsSpec())
+    context.requestWeight(ws);
+}
+
+void fillWeights(std::vector<Weight> &weights, const RunLayerContext &context,
+                 bool training, const unsigned int max_timestep,
+                 const unsigned int timestep, bool test) {
+  weights.resize(context.getNumWeights());
+  for (unsigned int i = 0; i < context.getNumWeights(); ++i) {
+    if (training && (!test || i < 3u)) {
+      weights[i] = Weight(context.getWeight(i), context.getWeightGrad(i),
+                          context.getWeightName(i));
+    } else {
+      weights[i] =
+        Weight(context.getWeight(i), Tensor(), context.getWeightName(i));
+    }
+  }
+}
+
+const std::vector<Weight *> getWeights(std::vector<Weight> &weights) {
+  std::vector<Weight *> ret(weights.size());
+  for (unsigned int i = 0; i < weights.size(); ++i) {
+    ret[i] = &weights[i];
+  }
+  return ret;
+}
+
+void fillInputs(std::vector<Var_Grad> &inputs, RunLayerContext &context,
+                bool training, const std::vector<unsigned int> &wt_idx,
+                const unsigned int max_timestep, const unsigned int timestep) {
+  inputs.resize(3);
+  Tensor empty;
+  const TensorDim &output_dim = context.getOutput(wt_idx[0]).getDim();
+  const unsigned int batch_size = output_dim.batch();
+  const unsigned int unit = output_dim.width();
+
+  const Tensor &input = context.getInput(wt_idx[0]);
+  const Tensor &outgoing_derivative =
+    training ? context.getOutgoingDerivative(0) : empty;
+
+  Tensor &hidden_state = context.getTensor(wt_idx[1]);
+  hidden_state.reshape({max_timestep, 1, batch_size, unit});
+  Tensor &hidden_state_derivative =
+    training ? context.getTensorGrad(wt_idx[1]) : empty;
+  if (training) {
+    hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit});
+  }
+
+  Tensor &cell_state = context.getTensor(wt_idx[2]);
+  cell_state.reshape({max_timestep, 1, batch_size, unit});
+  Tensor &cell_state_derivative =
+    training ? context.getTensorGrad(wt_idx[2]) : empty;
+  if (training) {
+    cell_state_derivative.reshape({max_timestep, 1, batch_size, unit});
+  }
+
+  Tensor prev_hidden_state;
+  Tensor prev_hidden_state_derivative;
+  Tensor prev_cell_state;
+  Tensor prev_cell_state_derivative;
+  if (!timestep) {
+    prev_hidden_state = Tensor(batch_size, 1, 1, unit);
+    prev_hidden_state.setZero();
+    prev_hidden_state_derivative = Tensor(batch_size, 1, 1, unit);
+    prev_hidden_state_derivative.setZero();
+    prev_cell_state = Tensor(batch_size, 1, 1, unit);
+    prev_cell_state.setZero();
+    prev_cell_state_derivative = Tensor(batch_size, 1, 1, unit);
+    prev_cell_state_derivative.setZero();
+  } else {
+    prev_hidden_state = hidden_state.getBatchSlice(timestep - 1, 1);
+    prev_hidden_state.reshape({batch_size, 1, 1, unit});
+    if (training) {
+      prev_hidden_state_derivative =
+        hidden_state_derivative.getBatchSlice(timestep - 1, 1);
+      prev_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
+    }
+    prev_cell_state = cell_state.getBatchSlice(timestep - 1, 1);
+    prev_cell_state.reshape({batch_size, 1, 1, unit});
+    if (training) {
+      prev_cell_state_derivative =
+        cell_state_derivative.getBatchSlice(timestep - 1, 1);
+      prev_cell_state_derivative.reshape({batch_size, 1, 1, unit});
+    }
+  }
+
+  inputs[0] = Var_Grad(input, outgoing_derivative);
+  inputs[1] = Var_Grad(prev_hidden_state, prev_hidden_state_derivative,
+                       context.getTensorName(wt_idx[1]));
+  inputs[2] = Var_Grad(prev_cell_state, prev_cell_state_derivative,
+                       context.getTensorName(wt_idx[2]));
+}
+
+const std::vector<Var_Grad *> getInputs(std::vector<Var_Grad> &inputs) {
+  std::vector<Var_Grad *> ret(inputs.size());
+  for (unsigned int i = 0; i < inputs.size(); ++i) {
+    ret[i] = &inputs[i];
+  }
+  return ret;
+}
+
+void fillOutputs(std::vector<Var_Grad> &outputs, RunLayerContext &context,
+                 bool training, const std::vector<unsigned int> &wt_idx,
+                 const unsigned int max_timestep, const unsigned int timestep) {
+  outputs.resize(2);
+  Tensor empty;
+  const TensorDim &output_dim = context.getOutput(wt_idx[0]).getDim();
+  const unsigned int batch_size = output_dim.batch();
+  const unsigned int unit = output_dim.width();
+
+  Tensor &hidden_state = context.getTensor(wt_idx[1]);
+  hidden_state.reshape({max_timestep, 1, batch_size, unit});
+  Tensor next_hidden_state = hidden_state.getBatchSlice(timestep, 1);
+  next_hidden_state.reshape({batch_size, 1, 1, unit});
+  Tensor next_hidden_state_derivative;
+  if (training) {
+    Tensor &hidden_state_derivative = context.getTensorGrad(wt_idx[1]);
+    hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit});
+    next_hidden_state_derivative =
+      hidden_state_derivative.getBatchSlice(timestep, 1);
+    next_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
+  }
+
+  Tensor &cell_state = context.getTensor(wt_idx[2]);
+  cell_state.reshape({max_timestep, 1, batch_size, unit});
+  Tensor next_cell_state = cell_state.getBatchSlice(timestep, 1);
+  next_cell_state.reshape({batch_size, 1, 1, unit});
+  Tensor next_cell_state_derivative;
+  if (training) {
+    Tensor &cell_state_derivative = context.getTensorGrad(wt_idx[2]);
+    cell_state_derivative.reshape({max_timestep, 1, batch_size, unit});
+    next_cell_state_derivative =
+      cell_state_derivative.getBatchSlice(timestep, 1);
+    next_cell_state_derivative.reshape({batch_size, 1, 1, unit});
+  }
+
+  outputs[0] = Var_Grad(next_hidden_state, next_hidden_state_derivative,
+                        context.getTensorName(wt_idx[1]));
+  outputs[1] = Var_Grad(next_cell_state, next_cell_state_derivative,
+                        context.getTensorName(wt_idx[2]));
+}
+
+const std::vector<Var_Grad *> getOutputs(std::vector<Var_Grad> &outputs) {
+  std::vector<Var_Grad *> ret(outputs.size());
+  for (unsigned int i = 0; i < outputs.size(); ++i) {
+    ret[i] = &outputs[i];
+  }
+  return ret;
+}
+
+void fillTensors(std::vector<Var_Grad> &tensors, RunLayerContext &context,
+                 bool training, const std::vector<unsigned int> &wt_idx,
+                 const unsigned int max_timestep, const unsigned int timestep) {
+  tensors.resize(1);
+
+#if ENABLE_SHARING_WT_IDX
+  const Tensor &ifgo = context.getTensor(wt_idx[0]);
+  Tensor empty;
+  const Tensor &ifgo_derivative =
+    training ? context.getTensorGrad(wt_idx[0]) : empty;
+  tensors[0] =
+    Var_Grad(ifgo, ifgo_derivative, context.getTensorName(wt_idx[0]));
+#else
+  const TensorDim &output_dim = context.getOutput(0).getDim();
+  const unsigned int batch_size = output_dim.batch();
+  const unsigned int unit = output_dim.width();
+
+  Tensor &ifgo = context.getTensor(wt_idx[0]);
+  const unsigned int NUM_GATE = ifgo.width() / unit;
+  ifgo.reshape({max_timestep, 1, batch_size, NUM_GATE * unit});
+  Tensor ifgo_t = ifgo.getBatchSlice(timestep, 1);
+  ifgo_t.reshape({batch_size, 1, 1, NUM_GATE * unit});
+  Tensor ifgo_derivative_t;
+  if (training) {
+    Tensor &ifgo_derivative = context.getTensorGrad(wt_idx[0]);
+    ifgo_derivative.reshape({max_timestep, 1, batch_size, NUM_GATE * unit});
+    ifgo_derivative_t = ifgo_derivative.getBatchSlice(timestep, 1);
+    ifgo_derivative_t.reshape({batch_size, 1, 1, NUM_GATE * unit});
+  }
+  tensors[0] =
+    Var_Grad(ifgo_t, ifgo_derivative_t, context.getTensorName(wt_idx[0]));
+  context.getTensorName(wt_idx[0]);
+#endif
+}
+
+const std::vector<Var_Grad *> getTensors(std::vector<Var_Grad> &tensors) {
+  std::vector<Var_Grad *> ret(tensors.size());
+  for (unsigned int i = 0; i < tensors.size(); ++i) {
+    ret[i] = &tensors[i];
+  }
+  return ret;
+}
+
+} // namespace init_lstm_context
+
+enum INDEX {
+  INPUT = 0,
+  HIDDEN_STATE_IN = 1,
+  CELL_STATE_IN = 2,
+  HIDDEN_STATE_OUT = 0,
+  CELL_STATE_OUT = 1
+};
+
+enum LSTMCellCoreParams {
+  weight_ih,
+  weight_hh,
+  bias_ih,
+  ifgo,
+};
+
+LSTMCellCoreLayer::LSTMCellCoreLayer() :
+  LayerImpl(),
+  lstmcell_core_props(
+    props::Unit(), props::HiddenStateActivation() = ActivationType::ACT_TANH,
+    props::RecurrentActivation() = ActivationType::ACT_SIGMOID),
+  wt_idx({0}),
+  acti_func(ActivationType::ACT_NONE, true),
+  recurrent_acti_func(ActivationType::ACT_NONE, true) {}
+
+void LSTMCellCoreLayer::finalize(InitLayerContext &context) {
+  NNTR_THROW_IF(std::get<props::Unit>(lstmcell_core_props).empty(),
+                std::invalid_argument)
+    << "unit property missing for lstmcell_core layer";
+  const unsigned int unit = std::get<props::Unit>(lstmcell_core_props).get();
+  const nntrainer::props::HiddenStateActivation hidden_state_activation_type =
+    std::get<props::HiddenStateActivation>(lstmcell_core_props);
+  const nntrainer::props::RecurrentActivation recurrent_activation_type =
+    std::get<props::RecurrentActivation>(lstmcell_core_props);
+
+#if ENBABLE_SHARING_WEIGHT
+  const Tensor::Initializer weight_initializer =
+    std::get<props::WeightInitializer>(*layer_impl_props);
+  const Tensor::Initializer bias_initializer =
+    std::get<props::BiasInitializer>(*layer_impl_props);
+  const nntrainer::WeightRegularizer weight_regularizer =
+    std::get<props::WeightRegularizer>(*layer_impl_props);
+  const float weight_regularizer_constant =
+    std::get<props::WeightRegularizerConstant>(*layer_impl_props);
+#endif
+
+  if (context.getNumInputs() != 3)
+    throw std::invalid_argument("LSTMCellCore layer should takes 3 input");
+
+  // input_dim = [ batch, 1, 1, feature_size ]
+  const TensorDim &input_dim = context.getInputDimensions()[0];
+  if (input_dim.height() != 1 || input_dim.channel() != 1)
+    throw std::invalid_argument(
+      "Input must be single time dimension for LSTMCellCore");
+  // input_hidden_state_dim = [ batch, 1, 1, unit ]
+  const TensorDim &input_hidden_state_dim = context.getInputDimensions()[1];
+  if (input_hidden_state_dim.channel() != 1 ||
+      input_hidden_state_dim.height() != 1)
+    throw std::invalid_argument("Input hidden state's dimension should be "
+                                "[batch, 1, 1, unit] for LSTMCellCore");
+  // input_cell_state_dim = [ batch, 1, 1, unit ]
+  const TensorDim &input_cell_state_dim = context.getInputDimensions()[2];
+  if (input_cell_state_dim.channel() != 1 || input_cell_state_dim.height() != 1)
+    throw std::invalid_argument("Input cell state's dimension should be "
+                                "[batch, 1, 1, unit] for LSTMCellCore");
+
+  const unsigned int batch_size = input_dim.batch();
+#if ENABLE_SHARING_WT_IDX
+  const unsigned int feature_size = input_dim.width();
+#endif
+
+  const TensorDim output_dim(batch_size, 1, 1, unit);
+  const TensorDim output_hidden_state_dim = input_hidden_state_dim;
+  const TensorDim output_cell_state_dim = input_cell_state_dim;
+
+  context.setOutputDimensions(
+    {output_dim, output_hidden_state_dim, output_cell_state_dim});
+
+#if ENABLE_SHARING_WT_IDX
+  // weight_initializer can be set seperately. weight_ih initializer,
+  // weight_hh initializer kernel initializer & recurrent_initializer in keras
+  // for now, it is set same way.
+
+  // - weight_ih ( input to hidden )
+  //  : [1, 1, feature_size, NUM_GATE x unit] -> i, f, g, o
+  TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
+  wt_idx[LSTMCellCoreParams::weight_ih] =
+    context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
+                          weight_regularizer_constant, "weight_ih", true);
+  // - weight_hh ( hidden to hidden )
+  //  : [1, 1, unit, NUM_GATE x unit] -> i, f, g, o
+  TensorDim weight_hh_dim({unit, NUM_GATE * unit});
+  wt_idx[LSTMCellCoreParams::weight_hh] =
+    context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
+                          weight_regularizer_constant, "weight_hh", true);
+  // - bias_ih ( input bias )
+  //  : [1, 1, 1, NUM_GATE x unit] -> i, f, g, o
+  TensorDim bias_ih_dim({NUM_GATE * unit});
+  wt_idx[LSTMCellCoreParams::bias_ih] =
+    context.requestWeight(bias_ih_dim, bias_initializer,
+                          WeightRegularizer::NONE, 1.0f, "bias_ih", true);
+#endif
+
+#if ENABLE_SHARING_WT_IDX
+  TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit);
+  wt_idx[LSTMCellCoreParams::ifgo] =
+    context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true,
+                          TensorLifespan::ITERATION_LIFESPAN);
+#endif
+  acti_func.setActiFunc(hidden_state_activation_type.get());
+  recurrent_acti_func.setActiFunc(recurrent_activation_type.get());
+}
+
+void LSTMCellCoreLayer::setProperty(const std::vector<std::string> &values) {
+  std::vector<std::string> remain_props =
+    loadProperties(values, lstmcell_core_props);
+  LayerImpl::setProperty(remain_props);
+}
+
+void LSTMCellCoreLayer::exportTo(Exporter &exporter,
+                                 const ExportMethods &method) const {
+#if ENABLE_SHARING_WT_IDX
+  LayerImpl::exportTo(exporter, method);
+#endif
+  exporter.saveResult(lstmcell_core_props, method, this);
+}
+
+void LSTMCellCoreLayer::forwarding(RunLayerContext &context, bool training) {
+  const unsigned int unit = std::get<props::Unit>(lstmcell_core_props).get();
+
+  const Tensor &input = context.getInput(INDEX::INPUT);
+  const Tensor &prev_hidden_state = context.getInput(INDEX::HIDDEN_STATE_IN);
+  const Tensor &prev_cell_state = context.getInput(INDEX::CELL_STATE_IN);
+  const TensorDim &input_dim = input.getDim();
+  const unsigned int batch_size = input_dim.batch();
+
+  Tensor &next_hidden_state = context.getOutput(INDEX::HIDDEN_STATE_OUT);
+  Tensor &next_cell_state = context.getOutput(INDEX::CELL_STATE_OUT);
+
+#if ENABLE_SHARING_WT_IDX
+  const Tensor &weight_ih =
+    context.getWeight(wt_idx[LSTMCellCoreParams::weight_ih]);
+  const Tensor &weight_hh =
+    context.getWeight(wt_idx[LSTMCellCoreParams::weight_hh]);
+  const Tensor &bias_ih =
+    context.getWeight(wt_idx[LSTMCellCoreParams::bias_ih]);
+#else
+  const Tensor &weight_ih = context.getWeight(LSTMCellCoreParams::weight_ih);
+  const Tensor &weight_hh = context.getWeight(LSTMCellCoreParams::weight_hh);
+  const Tensor &bias_ih = context.getWeight(LSTMCellCoreParams::bias_ih);
+#endif
+
+#if ENABLE_SHARING_WT_IDX
+  Tensor &ifgo = context.getTensor(wt_idx[LSTMCellCoreParams::ifgo]);
+#else
+  Tensor &ifgo = context.getTensor(0);
+#endif
+
+  input.dot(weight_ih, ifgo);
+  prev_hidden_state.dot(weight_hh, ifgo, false, false, 1.0);
+  ifgo.add_i(bias_ih);
+
+  Tensor input_forget_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
+  Tensor input_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
+  Tensor forget_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
+  Tensor memory_cell =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 2, false);
+  Tensor output_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 3, false);
+
+  recurrent_acti_func.run_fn(input_forget_gate, input_forget_gate);
+  recurrent_acti_func.run_fn(output_gate, output_gate);
+  acti_func.run_fn(memory_cell, memory_cell);
+
+  forget_gate.multiply_strided(prev_cell_state, next_cell_state);
+  memory_cell.multiply_strided(input_gate, next_cell_state, 1.0f);
+
+  acti_func.run_fn(next_cell_state, next_hidden_state);
+  next_hidden_state.multiply_i_strided(output_gate);
+}
+
+void LSTMCellCoreLayer::calcDerivative(RunLayerContext &context) {
+#if ENABLE_SHARING_WT_IDX
+  Tensor &ifgo_derivative =
+    context.getTensorGrad(wt_idx[LSTMCellCoreParams::ifgo]);
+  const Tensor &weight_ih =
+    context.getWeight(wt_idx[LSTMCellCoreParams::weight_ih]);
+#else
+  const Tensor &weight_ih = context.getWeight(LSTMCellCoreParams::weight_ih);
+  Tensor &ifgo_derivative = context.getTensorGrad(0);
+#endif
+  Tensor &outgoing_derivative = context.getOutgoingDerivative(INDEX::INPUT);
+
+  ifgo_derivative.dot(weight_ih, outgoing_derivative, false, true);
+}
+
+void LSTMCellCoreLayer::calcGradient(RunLayerContext &context) {
+  const unsigned int unit = std::get<props::Unit>(lstmcell_core_props).get();
+
+  const Tensor &input = context.getInput(INDEX::INPUT);
+  const TensorDim &input_dim = input.getDim();
+  const unsigned int batch_size = input_dim.batch();
+
+#if ENABLE_SHARING_WT_IDX
+  Tensor &djdweight_ih =
+    context.getWeightGrad(wt_idx[LSTMCellCoreParams::weight_ih]);
+  const Tensor &weight_hh =
+    context.getWeight(wt_idx[LSTMCellCoreParams::weight_hh]);
+  Tensor &djdweight_hh =
+    context.getWeightGrad(wt_idx[LSTMCellCoreParams::weight_hh]);
+  Tensor &djdbias_ih =
+    context.getWeightGrad(wt_idx[LSTMCellCoreParams::bias_ih]);
+#else
+  Tensor &djdweight_ih = context.getWeightGrad(LSTMCellCoreParams::weight_ih);
+  const Tensor &weight_hh = context.getWeight(LSTMCellCoreParams::weight_hh);
+  Tensor &djdweight_hh = context.getWeightGrad(LSTMCellCoreParams::weight_hh);
+  Tensor &djdbias_ih = context.getWeightGrad(LSTMCellCoreParams::bias_ih);
+#endif
+
+  const Tensor &prev_hidden_state = context.getInput(INDEX::HIDDEN_STATE_IN);
+  Tensor &prev_hidden_state_derivative =
+    context.getOutgoingDerivative(INDEX::HIDDEN_STATE_IN);
+  Tensor &next_hidden_state_derivative =
+    context.getIncomingDerivative(INDEX::HIDDEN_STATE_OUT);
+
+  Tensor &prev_cell_state = context.getInput(INDEX::CELL_STATE_IN);
+  Tensor &prev_cell_state_derivative =
+    context.getOutgoingDerivative(INDEX::CELL_STATE_IN);
+  Tensor &next_cell_state = context.getOutput(INDEX::CELL_STATE_OUT);
+  Tensor &next_cell_state_derivative =
+    context.getIncomingDerivative(INDEX::CELL_STATE_OUT);
+
+#if ENABLE_SHARING_WT_IDX
+  Tensor &ifgo = context.getTensor(wt_idx[LSTMCellCoreParams::ifgo]);
+  Tensor &ifgo_derivative =
+    context.getTensorGrad(wt_idx[LSTMCellCoreParams::ifgo]);
+#else
+  Tensor &ifgo = context.getTensor(0);
+  Tensor &ifgo_derivative = context.getTensorGrad(0);
+#endif
+
+  Tensor input_forget_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
+  Tensor input_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
+  Tensor forget_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
+  Tensor memory_cell =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 2, false);
+  Tensor output_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 3, false);
+
+  Tensor input_forget_gate_derivative =
+    ifgo_derivative.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
+  Tensor input_gate_derivative =
+    ifgo_derivative.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
+  Tensor forget_gate_derivative =
+    ifgo_derivative.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
+  Tensor memory_cell_derivative = ifgo_derivative.getSharedDataTensor(
+    {batch_size, 1, 1, unit}, unit * 2, false);
+  Tensor output_gate_derivative = ifgo_derivative.getSharedDataTensor(
+    {batch_size, 1, 1, unit}, unit * 3, false);
+
+  acti_func.run_fn(next_cell_state, next_cell_state);
+  next_hidden_state_derivative.multiply_strided(next_cell_state,
+                                                output_gate_derivative);
+
+  acti_func.run_prime_fn(next_cell_state, prev_cell_state_derivative,
+                         next_hidden_state_derivative);
+  prev_cell_state_derivative.multiply_i_strided(output_gate);
+  prev_cell_state_derivative.add_i(next_cell_state_derivative);
+
+  prev_cell_state_derivative.multiply_strided(input_gate,
+                                              memory_cell_derivative);
+  prev_cell_state_derivative.multiply_strided(memory_cell,
+                                              input_gate_derivative);
+
+  prev_cell_state_derivative.multiply_strided(prev_cell_state,
+                                              forget_gate_derivative);
+  prev_cell_state_derivative.multiply_i_strided(forget_gate);
+
+  recurrent_acti_func.run_prime_fn(output_gate, output_gate_derivative,
+                                   output_gate_derivative);
+  recurrent_acti_func.run_prime_fn(input_forget_gate,
+                                   input_forget_gate_derivative,
+                                   input_forget_gate_derivative);
+  acti_func.run_prime_fn(memory_cell, memory_cell_derivative,
+                         memory_cell_derivative);
+
+  ifgo_derivative.sum(0, djdbias_ih, 1.0f, 1.0f);
+  input.dot(ifgo_derivative, djdweight_ih, true, false, 1.0f);
+  prev_hidden_state.dot(ifgo_derivative, djdweight_hh, true, false, 1.0f);
+  ifgo_derivative.dot(weight_hh, prev_hidden_state_derivative, false, true);
+}
+
+void LSTMCellCoreLayer::setBatch(RunLayerContext &context, unsigned int batch) {
+  context.updateTensor(wt_idx[LSTMCellCoreParams::ifgo], batch);
+}
+
+} // namespace nntrainer
diff --git a/nntrainer/layers/lstmcell_core.h b/nntrainer/layers/lstmcell_core.h
new file mode 100644 (file)
index 0000000..45f5bb9
--- /dev/null
@@ -0,0 +1,137 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
+ *
+ * @file   lstmcell_core.h
+ * @date   25 November 2021
+ * @brief  This is LSTMCellCore Layer Class of Neural Network
+ * @see           https://github.com/nnstreamer/nntrainer
+ * @author hyeonseok lee <hs89.lee@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
+
+#ifndef __LSTMCELLCORE_H__
+#define __LSTMCELLCORE_H__
+#ifdef __cplusplus
+
+#include <acti_func.h>
+#include <common_properties.h>
+#include <layer_impl.h>
+
+namespace nntrainer {
+
+namespace init_lstm_context {
+void fillLayerInitContext(InitLayerContext &context,
+                          const InitLayerContext &core_context);
+void fillWeights(std::vector<Weight> &weights, const RunLayerContext &context,
+                 bool training, const unsigned int max_timestep,
+                 const unsigned int timestep, bool test = false);
+const std::vector<Weight *> getWeights(std::vector<Weight> &weights);
+void fillInputs(std::vector<Var_Grad> &inputs, RunLayerContext &context,
+                bool training, const std::vector<unsigned int> &wt_idx,
+                const unsigned int max_timestep, const unsigned int timestep);
+const std::vector<Var_Grad *> getInputs(std::vector<Var_Grad> &inputs);
+void fillOutputs(std::vector<Var_Grad> &outputs, RunLayerContext &context,
+                 bool training, const std::vector<unsigned int> &wt_idx,
+                 const unsigned int max_timestep, const unsigned int timestep);
+const std::vector<Var_Grad *> getOutputs(std::vector<Var_Grad> &outputs);
+void fillTensors(std::vector<Var_Grad> &tensors, RunLayerContext &context,
+                 bool training, const std::vector<unsigned int> &wt_idx,
+                 const unsigned int max_timestep, const unsigned int timestep);
+const std::vector<Var_Grad *> getTensors(std::vector<Var_Grad> &tensors);
+} // namespace init_lstm_context
+
+/**
+ * @class   LSTMCellCoreLayer
+ * @brief   LSTMCellCoreLayer
+ */
+class LSTMCellCoreLayer : public LayerImpl {
+public:
+  /**
+   * @brief     Constructor of LSTMCellLayer
+   */
+  LSTMCellCoreLayer();
+
+  /**
+   * @brief     Destructor of LSTMCellLayer
+   */
+  ~LSTMCellCoreLayer() = default;
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(InitLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+   */
+  void forwarding(RunLayerContext &context, bool training) override;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(RunLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::calcGradient(RunLayerContext &context)
+   */
+  void calcGradient(RunLayerContext &context) override;
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
+   */
+  void exportTo(Exporter &exporter, const ExportMethods &method) const override;
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const override {
+    return LSTMCellCoreLayer::type;
+  };
+
+  /**
+   * @copydoc Layer::supportBackwarding()
+   */
+  bool supportBackwarding() const override { return true; }
+
+  /**
+   * @copydoc Layer::setProperty(const PropertyType type, const std::string
+   * &value)
+   */
+  void setProperty(const std::vector<std::string> &values) override;
+
+  /**
+   * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch)
+   */
+  void setBatch(RunLayerContext &context, unsigned int batch) override;
+
+  inline static const std::string type = "lstmcell_core";
+
+private:
+  static constexpr unsigned int NUM_GATE = 4;
+
+  /**
+   * Unit: number of output neurons
+   * HiddenStateActivation: activation type for hidden state. default is tanh
+   * RecurrentActivation: activation type for recurrent. default is sigmoid
+   *
+   * */
+  std::tuple<props::Unit, props::HiddenStateActivation,
+             props::RecurrentActivation>
+    lstmcell_core_props;
+  std::array<unsigned int, 4> wt_idx; /**< indices of the weights */
+
+  /**
+   * @brief     activation function for h_t : default is tanh
+   */
+  ActiFunc acti_func;
+
+  /**
+   * @brief     activation function for recurrent : default is sigmoid
+   */
+  ActiFunc recurrent_acti_func;
+};
+} // namespace nntrainer
+
+#endif /* __cplusplus */
+#endif /* __LSTMCELLCORE_H__ */
index 3d3f2c8..e863dbc 100644 (file)
@@ -26,6 +26,7 @@ layer_sources = [
   'acti_func.cpp',
   'lstm.cpp',
   'lstmcell.cpp',
+  'lstmcell_core.cpp',
   'time_dist.cpp',
   'common_properties.cpp',
   'split_layer.cpp',