[zoneout lstmcell] refactoring zoneout lstmcell layer
author hyeonseok lee <hs89.lee@samsung.com>
Sat, 18 Dec 2021 01:21:29 +0000 (10:21 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Wed, 29 Dec 2021 06:20:14 +0000 (15:20 +0900)
 - Refactor the zoneout lstmcell layer to use the lstmcell core functions
   (see the sketch after this list).
 - Preserve the lstm_cell_state tensor for calcGradient.
 - Remove the lstmcell core layer.
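
A minimal sketch of the new control flow, assuming the nntrainer APIs exactly as
they appear in the diff below (function names, argument order, and variable
names are copied from zoneout_lstmcell.cpp and lstmcell_core.h; the elided setup
is only summarized in comments):

  // After this refactoring, ZoneoutLSTMCellLayer delegates the core LSTM math
  // to the free functions declared in lstmcell_core.h instead of driving an
  // internal LSTMCellCoreLayer through a temporary RunLayerContext.
  void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
    // ... fetch weights, biases, ifgo, and the per-timestep state slices ...
    // lstm_cell_state keeps the pre-zoneout cell state so calcGradient can
    // reuse it later.
    lstmcell_forwarding(unit, batch_size, disable_bias, integrate_bias,
                        acti_func, recurrent_acti_func, input, prev_hidden_state,
                        prev_cell_state, hidden_state, lstm_cell_state, weight_ih,
                        weight_hh, bias_h, bias_ih, bias_hh, ifgo);
    // ... apply the hidden/cell state zoneout masks, then copy hidden_state
    //     to the layer output ...
  }

  void ZoneoutLSTMCellLayer::calcDerivative(RunLayerContext &context) {
    // ... fetch d_ifgo, weight_ih, and the outgoing derivative ...
    lstmcell_calcDerivative(d_ifgo, weight_ih, outgoing_derivative);
  }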

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/compiler/recurrent_realizer.cpp
nntrainer/layers/lstmcell_core.cpp
nntrainer/layers/lstmcell_core.h
nntrainer/layers/zoneout_lstmcell.cpp
nntrainer/layers/zoneout_lstmcell.h

index cf9c116..d750100 100644
@@ -23,7 +23,6 @@
 #include <layer_node.h>
 #include <lstm.h>
 #include <lstmcell.h>
-#include <lstmcell_core.h>
 #include <nntrainer_error.h>
 #include <node_exporter.h>
 #include <recurrent_realizer.h>
@@ -180,7 +179,6 @@ static void propagateTimestep(LayerNode *node, unsigned int time_step,
     return node->getType() == RNNCellLayer::type ||
            node->getType() == LSTMLayer::type ||
            node->getType() == LSTMCellLayer::type ||
-           node->getType() == LSTMCellCoreLayer::type ||
            node->getType() == ZoneoutLSTMCellLayer::type ||
            node->getType() == GRUCellLayer::type;
   };
index 9bf6fe7..c6d31bd 100644
  *
  * @file   lstmcell_core.cpp
  * @date   25 November 2021
- * @brief  This is LSTMCellCore Layer Class of Neural Network
+ * @brief  These are lstm core functions.
  * @see    https://github.com/nnstreamer/nntrainer
  * @author hyeonseok lee <hs89.lee@samsung.com>
  * @bug    No known bugs except for NYI items
  *
  */
 
-#include <layer_context.h>
 #include <lstmcell_core.h>
 #include <nntrainer_error.h>
 #include <nntrainer_log.h>
-#include <node_exporter.h>
-
-// ENABLE_SHARING_WT_IDX implies does the wt_idx of lstm_core can be shared
-// with lstm_cell variant layer.
-// Todo: remove this if sharing wt_idx with other lstm variant is enabled
-#define ENABLE_SHARING_WT_IDX 0
 
 namespace nntrainer {
 
-namespace init_lstm_context {
-void fillLayerInitContext(InitLayerContext &context,
-                          const InitLayerContext &core_context) {
-  /** real set the input flags */
-  auto const &input_dims = context.getInputDimensions();
-  for (unsigned int idx = 0; idx < core_context.getNumInputs(); idx++) {
-    context.setDynDimFlagInputDimension(idx, input_dims[idx].getDynDimFlag());
-    context.setEffDimFlagInputDimension(idx, input_dims[idx].getEffDimFlag());
-  }
-
-  /** real request of tensors */
-  for (auto const &ts : core_context.getTensorsSpec())
-    context.requestTensor(ts);
-
-  /** real request of weights */
-  for (auto const &ws : core_context.getWeightsSpec())
-    context.requestWeight(ws);
-}
-
-void fillWeights(std::vector<Weight> &weights, const RunLayerContext &context,
-                 bool training, const std::vector<unsigned int> &wt_idx,
-                 const unsigned int max_timestep, const unsigned int timestep,
-                 bool test) {
-  weights.resize(context.getNumWeights());
-  for (unsigned int i = 0; i < context.getNumWeights(); ++i) {
-    if (training && (!test || i < context.getNumWeights() - 2)) {
-      weights[i] =
-        Weight(context.getWeight(wt_idx[i]), context.getWeightGrad(wt_idx[i]),
-               context.getWeightName(wt_idx[i]));
-    } else {
-      weights[i] = Weight(context.getWeight(wt_idx[i]), Tensor(),
-                          context.getWeightName(wt_idx[i]));
-    }
-  }
-}
-
-const std::vector<Weight *> getWeights(std::vector<Weight> &weights) {
-  std::vector<Weight *> ret(weights.size());
-  for (unsigned int i = 0; i < weights.size(); ++i) {
-    ret[i] = &weights[i];
-  }
-  return ret;
-}
-
-void fillInputs(std::vector<Var_Grad> &inputs, RunLayerContext &context,
-                bool training, const std::vector<unsigned int> &wt_idx,
-                const unsigned int max_timestep, const unsigned int timestep) {
-  inputs.resize(3);
-  Tensor empty;
-  const TensorDim &output_dim = context.getOutput(wt_idx[0]).getDim();
-  const unsigned int batch_size = output_dim.batch();
-  const unsigned int unit = output_dim.width();
-
-  const Tensor &input = context.getInput(wt_idx[0]);
-  const Tensor &outgoing_derivative =
-    training ? context.getOutgoingDerivative(0) : empty;
-
-  Tensor &hidden_state = context.getTensor(wt_idx[1]);
-  hidden_state.reshape({max_timestep, 1, batch_size, unit});
-  Tensor &hidden_state_derivative =
-    training ? context.getTensorGrad(wt_idx[1]) : empty;
-  if (training) {
-    hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit});
-  }
-
-  Tensor &cell_state = context.getTensor(wt_idx[2]);
-  cell_state.reshape({max_timestep, 1, batch_size, unit});
-  Tensor &cell_state_derivative =
-    training ? context.getTensorGrad(wt_idx[2]) : empty;
-  if (training) {
-    cell_state_derivative.reshape({max_timestep, 1, batch_size, unit});
-  }
-
-  Tensor prev_hidden_state;
-  Tensor prev_hidden_state_derivative;
-  Tensor prev_cell_state;
-  Tensor prev_cell_state_derivative;
-  if (!timestep) {
-    prev_hidden_state = Tensor(batch_size, 1, 1, unit);
-    prev_hidden_state.setZero();
-    prev_hidden_state_derivative = Tensor(batch_size, 1, 1, unit);
-    prev_hidden_state_derivative.setZero();
-    prev_cell_state = Tensor(batch_size, 1, 1, unit);
-    prev_cell_state.setZero();
-    prev_cell_state_derivative = Tensor(batch_size, 1, 1, unit);
-    prev_cell_state_derivative.setZero();
-  } else {
-    prev_hidden_state = hidden_state.getBatchSlice(timestep - 1, 1);
-    prev_hidden_state.reshape({batch_size, 1, 1, unit});
-    if (training) {
-      prev_hidden_state_derivative =
-        hidden_state_derivative.getBatchSlice(timestep - 1, 1);
-      prev_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
-    }
-    prev_cell_state = cell_state.getBatchSlice(timestep - 1, 1);
-    prev_cell_state.reshape({batch_size, 1, 1, unit});
-    if (training) {
-      prev_cell_state_derivative =
-        cell_state_derivative.getBatchSlice(timestep - 1, 1);
-      prev_cell_state_derivative.reshape({batch_size, 1, 1, unit});
-    }
-  }
-
-  inputs[0] = Var_Grad(input, outgoing_derivative, "lstmcell_core input");
-  inputs[1] = Var_Grad(prev_hidden_state, prev_hidden_state_derivative,
-                       context.getTensorName(wt_idx[1]));
-  inputs[2] = Var_Grad(prev_cell_state, prev_cell_state_derivative,
-                       context.getTensorName(wt_idx[2]));
-}
-
-const std::vector<Var_Grad *> getInputs(std::vector<Var_Grad> &inputs) {
-  std::vector<Var_Grad *> ret(inputs.size());
-  for (unsigned int i = 0; i < inputs.size(); ++i) {
-    ret[i] = &inputs[i];
-  }
-  return ret;
-}
-
-void fillOutputs(std::vector<Var_Grad> &outputs, RunLayerContext &context,
-                 bool training, const std::vector<unsigned int> &wt_idx,
-                 const unsigned int max_timestep, const unsigned int timestep) {
-  outputs.resize(2);
-  Tensor empty;
-  const TensorDim &output_dim = context.getOutput(wt_idx[0]).getDim();
-  const unsigned int batch_size = output_dim.batch();
-  const unsigned int unit = output_dim.width();
-
-  Tensor &hidden_state = context.getTensor(wt_idx[1]);
-  hidden_state.reshape({max_timestep, 1, batch_size, unit});
-  Tensor next_hidden_state = hidden_state.getBatchSlice(timestep, 1);
-  next_hidden_state.reshape({batch_size, 1, 1, unit});
-  Tensor next_hidden_state_derivative;
-  if (training) {
-    Tensor &hidden_state_derivative = context.getTensorGrad(wt_idx[1]);
-    hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit});
-    next_hidden_state_derivative =
-      hidden_state_derivative.getBatchSlice(timestep, 1);
-    next_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
-  }
-
-  Tensor &cell_state = context.getTensor(wt_idx[2]);
-  cell_state.reshape({max_timestep, 1, batch_size, unit});
-  Tensor next_cell_state = cell_state.getBatchSlice(timestep, 1);
-  next_cell_state.reshape({batch_size, 1, 1, unit});
-  Tensor next_cell_state_derivative;
-  if (training) {
-    Tensor &cell_state_derivative = context.getTensorGrad(wt_idx[2]);
-    cell_state_derivative.reshape({max_timestep, 1, batch_size, unit});
-    next_cell_state_derivative =
-      cell_state_derivative.getBatchSlice(timestep, 1);
-    next_cell_state_derivative.reshape({batch_size, 1, 1, unit});
-  }
-
-  outputs[0] = Var_Grad(next_hidden_state, next_hidden_state_derivative,
-                        context.getTensorName(wt_idx[1]));
-  outputs[1] = Var_Grad(next_cell_state, next_cell_state_derivative,
-                        context.getTensorName(wt_idx[2]));
-}
-
-const std::vector<Var_Grad *> getOutputs(std::vector<Var_Grad> &outputs) {
-  std::vector<Var_Grad *> ret(outputs.size());
-  for (unsigned int i = 0; i < outputs.size(); ++i) {
-    ret[i] = &outputs[i];
-  }
-  return ret;
-}
-
-void fillTensors(std::vector<Var_Grad> &tensors, RunLayerContext &context,
-                 bool training, const std::vector<unsigned int> &wt_idx,
-                 const unsigned int max_timestep, const unsigned int timestep) {
-  tensors.resize(1);
-
-#if ENABLE_SHARING_WT_IDX
-  const Tensor &ifgo = context.getTensor(wt_idx[0]);
-  Tensor empty;
-  const Tensor &ifgo_derivative =
-    training ? context.getTensorGrad(wt_idx[0]) : empty;
-  tensors[0] =
-    Var_Grad(ifgo, ifgo_derivative, context.getTensorName(wt_idx[0]));
-#else
-  const TensorDim &output_dim = context.getOutput(0).getDim();
-  const unsigned int batch_size = output_dim.batch();
-  const unsigned int unit = output_dim.width();
-
-  Tensor &ifgo = context.getTensor(wt_idx[0]);
-  const unsigned int NUM_GATE = ifgo.width() / unit;
-  ifgo.reshape({max_timestep, 1, batch_size, NUM_GATE * unit});
-  Tensor ifgo_t = ifgo.getBatchSlice(timestep, 1);
-  ifgo_t.reshape({batch_size, 1, 1, NUM_GATE * unit});
-  Tensor ifgo_derivative_t;
-  if (training) {
-    Tensor &ifgo_derivative = context.getTensorGrad(wt_idx[0]);
-    ifgo_derivative.reshape({max_timestep, 1, batch_size, NUM_GATE * unit});
-    ifgo_derivative_t = ifgo_derivative.getBatchSlice(timestep, 1);
-    ifgo_derivative_t.reshape({batch_size, 1, 1, NUM_GATE * unit});
-  }
-  tensors[0] =
-    Var_Grad(ifgo_t, ifgo_derivative_t, context.getTensorName(wt_idx[0]));
-#endif
-}
-
-const std::vector<Var_Grad *> getTensors(std::vector<Var_Grad> &tensors) {
-  std::vector<Var_Grad *> ret(tensors.size());
-  for (unsigned int i = 0; i < tensors.size(); ++i) {
-    ret[i] = &tensors[i];
-  }
-  return ret;
-}
-
-} // namespace init_lstm_context
-
-enum INDEX {
-  INPUT = 0,
-  HIDDEN_STATE_IN = 1,
-  CELL_STATE_IN = 2,
-  HIDDEN_STATE_OUT = 0,
-  CELL_STATE_OUT = 1
-};
-
-enum LSTMCellCoreParams {
-  weight_ih,
-  weight_hh,
-  bias_h,
-  bias_ih,
-  bias_hh,
-  ifgo,
-};
-
-LSTMCellCoreLayer::LSTMCellCoreLayer() :
-  LayerImpl(),
-  lstmcell_core_props(
-    props::Unit(), props::HiddenStateActivation() = ActivationType::ACT_TANH,
-    props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
-    props::IntegrateBias()),
-  acti_func(ActivationType::ACT_NONE, true),
-  recurrent_acti_func(ActivationType::ACT_NONE, true) {
-  wt_idx.fill(std::numeric_limits<unsigned>::max());
-}
-
-void LSTMCellCoreLayer::finalize(InitLayerContext &context) {
-#if ENBABLE_SHARING_WEIGHT
-  const Tensor::Initializer weight_initializer =
-    std::get<props::WeightInitializer>(*layer_impl_props).get();
-  const Tensor::Initializer bias_initializer =
-    std::get<props::BiasInitializer>(*layer_impl_props).get();
-  const WeightRegularizer weight_regularizer =
-    std::get<props::WeightRegularizer>(*layer_impl_props).get();
-  const float weight_regularizer_constant =
-    std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
-  const bool disable_bias =
-    std::get<props::DisableBias>(*layer_impl_props).get();
-#endif
-
-  NNTR_THROW_IF(std::get<props::Unit>(lstmcell_core_props).empty(),
-                std::invalid_argument)
-    << "unit property missing for lstmcell_core layer";
-  const unsigned int unit = std::get<props::Unit>(lstmcell_core_props).get();
-  const ActivationType hidden_state_activation_type =
-    std::get<props::HiddenStateActivation>(lstmcell_core_props).get();
-  const ActivationType recurrent_activation_type =
-    std::get<props::RecurrentActivation>(lstmcell_core_props).get();
-#if ENBABLE_SHARING_WEIGHT
-  const bool integrate_bias =
-    std::get<props::IntegrateBias>(lstmcell_core_props).get();
-#endif
-
-  if (context.getNumInputs() != 3)
-    throw std::invalid_argument("LSTMCellCore layer should takes 3 input");
-
-  // input_dim = [ batch, 1, 1, feature_size ]
-  const TensorDim &input_dim = context.getInputDimensions()[0];
-  if (input_dim.height() != 1 || input_dim.channel() != 1)
-    throw std::invalid_argument(
-      "Input must be single time dimension for LSTMCellCore");
-  // input_hidden_state_dim = [ batch, 1, 1, unit ]
-  const TensorDim &input_hidden_state_dim = context.getInputDimensions()[1];
-  if (input_hidden_state_dim.channel() != 1 ||
-      input_hidden_state_dim.height() != 1)
-    throw std::invalid_argument("Input hidden state's dimension should be "
-                                "[batch, 1, 1, unit] for LSTMCellCore");
-  // input_cell_state_dim = [ batch, 1, 1, unit ]
-  const TensorDim &input_cell_state_dim = context.getInputDimensions()[2];
-  if (input_cell_state_dim.channel() != 1 || input_cell_state_dim.height() != 1)
-    throw std::invalid_argument("Input cell state's dimension should be "
-                                "[batch, 1, 1, unit] for LSTMCellCore");
-
-  const unsigned int batch_size = input_dim.batch();
-#if ENABLE_SHARING_WT_IDX
-  const unsigned int feature_size = input_dim.width();
-#endif
-
-  const TensorDim output_dim(batch_size, 1, 1, unit);
-  const TensorDim output_hidden_state_dim = input_hidden_state_dim;
-  const TensorDim output_cell_state_dim = input_cell_state_dim;
-
-  context.setOutputDimensions(
-    {output_dim, output_hidden_state_dim, output_cell_state_dim});
-
-#if ENABLE_SHARING_WT_IDX
-  // weight_initializer can be set seperately. weight_ih initializer,
-  // weight_hh initializer kernel initializer & recurrent_initializer in keras
-  // for now, it is set same way.
-
-  // - weight_ih ( input to hidden )
-  //  : [1, 1, feature_size, NUM_GATE x unit] -> i, f, g, o
-  TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
-  wt_idx[LSTMCellCoreParams::weight_ih] =
-    context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_ih", true);
-  // - weight_hh ( hidden to hidden )
-  //  : [1, 1, unit, NUM_GATE x unit] -> i, f, g, o
-  TensorDim weight_hh_dim({unit, NUM_GATE * unit});
-  wt_idx[LSTMCellCoreParams::weight_hh] =
-    context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_hh", true);
-  if (!disable_bias) {
-    if (integrate_bias) {
-      // - bias_h ( input bias, hidden bias are integrate to 1 bias )
-      //  : [1, 1, 1, NUM_GATE x unit] -> i, f, g, o
-      TensorDim bias_h_dim({NUM_GATE * unit});
-      wt_idx[LSTMCellCoreParams::bias_h] =
-        context.requestWeight(bias_h_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_h", true);
-    } else {
-      // - bias_ih ( input bias )
-      //  : [1, 1, 1, NUM_GATE x unit] -> i, f, g, o
-      TensorDim bias_ih_dim({NUM_GATE * unit});
-      wt_idx[LSTMCellCoreParams::bias_ih] =
-        context.requestWeight(bias_ih_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_ih", true);
-      // - bias_hh ( hidden bias )
-      //  : [1, 1, 1, NUM_GATE x unit] -> i, f, g, o
-      TensorDim bias_hh_dim({NUM_GATE * unit});
-      wt_idx[LSTMCellCoreParams::bias_hh] =
-        context.requestWeight(bias_hh_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_hh", true);
-    }
-  }
-#endif
-
-#if ENABLE_SHARING_WT_IDX
-  TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit);
-  wt_idx[LSTMCellCoreParams::ifgo] =
-    context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true,
-                          TensorLifespan::ITERATION_LIFESPAN);
-#endif
-  acti_func.setActiFunc(hidden_state_activation_type);
-  recurrent_acti_func.setActiFunc(recurrent_activation_type);
-}
-
-void LSTMCellCoreLayer::setProperty(const std::vector<std::string> &values) {
-  std::vector<std::string> remain_props =
-    loadProperties(values, lstmcell_core_props);
-  LayerImpl::setProperty(remain_props);
-}
-
-void LSTMCellCoreLayer::exportTo(Exporter &exporter,
-                                 const ExportMethods &method) const {
-#if ENABLE_SHARING_WT_IDX
-  LayerImpl::exportTo(exporter, method);
-#endif
-  exporter.saveResult(lstmcell_core_props, method, this);
-}
-
-void LSTMCellCoreLayer::forwarding(RunLayerContext &context, bool training) {
-  const bool disable_bias =
-    std::get<props::DisableBias>(*layer_impl_props).get();
-
-  const unsigned int unit = std::get<props::Unit>(lstmcell_core_props).get();
-  const bool integrate_bias =
-    std::get<props::IntegrateBias>(lstmcell_core_props).get();
-
-  const Tensor &input = context.getInput(INDEX::INPUT);
-  const unsigned int batch_size = input.getDim().batch();
-
-  const Tensor &prev_hidden_state = context.getInput(INDEX::HIDDEN_STATE_IN);
-  const Tensor &prev_cell_state = context.getInput(INDEX::CELL_STATE_IN);
-  Tensor &next_hidden_state = context.getOutput(INDEX::HIDDEN_STATE_OUT);
-  Tensor &next_cell_state = context.getOutput(INDEX::CELL_STATE_OUT);
-
-#if ENABLE_SHARING_WT_IDX
-  const Tensor &weight_ih =
-    context.getWeight(wt_idx[LSTMCellCoreParams::weight_ih]);
-  const Tensor &weight_hh =
-    context.getWeight(wt_idx[LSTMCellCoreParams::weight_hh]);
-  Tensor empty;
-  Tensor &bias_h = !disable_bias && integrate_bias
-                     ? context.getWeight(wt_idx[LSTMCellCoreParams::bias_h])
-                     : empty;
-  Tensor &bias_ih = !disable_bias && !integrate_bias
-                      ? context.getWeight(wt_idx[LSTMCellCoreParams::bias_ih])
-                      : empty;
-  Tensor &bias_hh = !disable_bias && !integrate_bias
-                      ? context.getWeight(wt_idx[LSTMCellCoreParams::bias_hh])
-                      : empty;
-#else
-  const Tensor &weight_ih = context.getWeight(LSTMCellCoreParams::weight_ih);
-  const Tensor &weight_hh = context.getWeight(LSTMCellCoreParams::weight_hh);
-  Tensor empty;
-  Tensor &bias_h = !disable_bias && integrate_bias
-                     ? context.getWeight(LSTMCellCoreParams::bias_h)
-                     : empty;
-  // subtract index by 1 cause there is no bias_h
-  Tensor &bias_ih = !disable_bias && !integrate_bias
-                      ? context.getWeight(LSTMCellCoreParams::bias_ih - 1)
-                      : empty;
-  Tensor &bias_hh = !disable_bias && !integrate_bias
-                      ? context.getWeight(LSTMCellCoreParams::bias_hh - 1)
-                      : empty;
-#endif
-
-#if ENABLE_SHARING_WT_IDX
-  Tensor &ifgo = context.getTensor(wt_idx[LSTMCellCoreParams::ifgo]);
-#else
-  Tensor &ifgo = context.getTensor(0);
-#endif
-
-  input.dot(weight_ih, ifgo);
-  prev_hidden_state.dot(weight_hh, ifgo, false, false, 1.0);
-  if (!disable_bias) {
-    if (integrate_bias) {
-      ifgo.add_i(bias_h);
-    } else {
-      ifgo.add_i(bias_ih);
-      ifgo.add_i(bias_hh);
-    }
-  }
-
-  Tensor input_forget_gate =
-    ifgo.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
-  Tensor input_gate =
-    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
-  Tensor forget_gate =
-    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
-  Tensor memory_cell =
-    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 2, false);
-  Tensor output_gate =
-    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 3, false);
-
-  recurrent_acti_func.run_fn(input_forget_gate, input_forget_gate);
-  recurrent_acti_func.run_fn(output_gate, output_gate);
-  acti_func.run_fn(memory_cell, memory_cell);
-
-  forget_gate.multiply_strided(prev_cell_state, next_cell_state);
-  memory_cell.multiply_strided(input_gate, next_cell_state, 1.0f);
-
-  acti_func.run_fn(next_cell_state, next_hidden_state);
-  next_hidden_state.multiply_i_strided(output_gate);
-}
-
-void LSTMCellCoreLayer::calcDerivative(RunLayerContext &context) {
-#if ENABLE_SHARING_WT_IDX
-  Tensor &ifgo_derivative =
-    context.getTensorGrad(wt_idx[LSTMCellCoreParams::ifgo]);
-  const Tensor &weight_ih =
-    context.getWeight(wt_idx[LSTMCellCoreParams::weight_ih]);
-#else
-  const Tensor &weight_ih = context.getWeight(LSTMCellCoreParams::weight_ih);
-  Tensor &ifgo_derivative = context.getTensorGrad(0);
-#endif
-  Tensor &outgoing_derivative = context.getOutgoingDerivative(INDEX::INPUT);
-
-  ifgo_derivative.dot(weight_ih, outgoing_derivative, false, true);
-}
-
-void LSTMCellCoreLayer::calcGradient(RunLayerContext &context) {
-  const bool disable_bias =
-    std::get<props::DisableBias>(*layer_impl_props).get();
-
-  const unsigned int unit = std::get<props::Unit>(lstmcell_core_props).get();
-  const bool integrate_bias =
-    std::get<props::IntegrateBias>(lstmcell_core_props).get();
-
-  const Tensor &input = context.getInput(INDEX::INPUT);
-  const unsigned int batch_size = input.getDim().batch();
-
-#if ENABLE_SHARING_WT_IDX
-  Tensor &djdweight_ih =
-    context.getWeightGrad(wt_idx[LSTMCellCoreParams::weight_ih]);
-  const Tensor &weight_hh =
-    context.getWeight(wt_idx[LSTMCellCoreParams::weight_hh]);
-  Tensor &djdweight_hh =
-    context.getWeightGrad(wt_idx[LSTMCellCoreParams::weight_hh]);
-  Tensor empty;
-  Tensor &djdbias_h =
-    !disable_bias && integrate_bias
-      ? context.getWeightGrad(wt_idx[LSTMCellCoreParams::bias_h])
-      : empty;
-  Tensor &djdbias_ih =
-    !disable_bias && !integrate_bias
-      ? context.getWeightGrad(wt_idx[LSTMCellCoreParams::bias_ih])
-      : empty;
-  Tensor &djdbias_hh =
-    !disable_bias && !integrate_bias
-      ? context.getWeightGrad(wt_idx[LSTMCellCoreParams::bias_hh])
-      : empty;
-#else
-  Tensor &djdweight_ih = context.getWeightGrad(LSTMCellCoreParams::weight_ih);
-  const Tensor &weight_hh = context.getWeight(LSTMCellCoreParams::weight_hh);
-  Tensor &djdweight_hh = context.getWeightGrad(LSTMCellCoreParams::weight_hh);
-  Tensor empty;
-  Tensor &djdbias_h = !disable_bias && integrate_bias
-                        ? context.getWeightGrad(LSTMCellCoreParams::bias_h)
-                        : empty;
-  // subtract index by 1 cause there is no bias_h(and also djdbias_h)
-  Tensor &djdbias_ih =
-    !disable_bias && !integrate_bias
-      ? context.getWeightGrad(LSTMCellCoreParams::bias_ih - 1)
-      : empty;
-  Tensor &djdbias_hh =
-    !disable_bias && !integrate_bias
-      ? context.getWeightGrad(LSTMCellCoreParams::bias_hh - 1)
-      : empty;
-#endif
-
-  const Tensor &prev_hidden_state = context.getInput(INDEX::HIDDEN_STATE_IN);
-  Tensor &prev_hidden_state_derivative =
-    context.getOutgoingDerivative(INDEX::HIDDEN_STATE_IN);
-  Tensor &next_hidden_state_derivative =
-    context.getIncomingDerivative(INDEX::HIDDEN_STATE_OUT);
-
-  Tensor &prev_cell_state = context.getInput(INDEX::CELL_STATE_IN);
-  Tensor &prev_cell_state_derivative =
-    context.getOutgoingDerivative(INDEX::CELL_STATE_IN);
-  Tensor &next_cell_state = context.getOutput(INDEX::CELL_STATE_OUT);
-  Tensor &next_cell_state_derivative =
-    context.getIncomingDerivative(INDEX::CELL_STATE_OUT);
-
-#if ENABLE_SHARING_WT_IDX
-  Tensor &ifgo = context.getTensor(wt_idx[LSTMCellCoreParams::ifgo]);
-  Tensor &ifgo_derivative =
-    context.getTensorGrad(wt_idx[LSTMCellCoreParams::ifgo]);
-#else
-  Tensor &ifgo = context.getTensor(0);
-  Tensor &ifgo_derivative = context.getTensorGrad(0);
-#endif
-
-  Tensor input_forget_gate =
-    ifgo.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
-  Tensor input_gate =
-    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
-  Tensor forget_gate =
-    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
-  Tensor memory_cell =
-    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 2, false);
-  Tensor output_gate =
-    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 3, false);
-
-  Tensor input_forget_gate_derivative =
-    ifgo_derivative.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
-  Tensor input_gate_derivative =
-    ifgo_derivative.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
-  Tensor forget_gate_derivative =
-    ifgo_derivative.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
-  Tensor memory_cell_derivative = ifgo_derivative.getSharedDataTensor(
-    {batch_size, 1, 1, unit}, unit * 2, false);
-  Tensor output_gate_derivative = ifgo_derivative.getSharedDataTensor(
-    {batch_size, 1, 1, unit}, unit * 3, false);
-
-  acti_func.run_fn(next_cell_state, next_cell_state);
-  next_hidden_state_derivative.multiply_strided(next_cell_state,
-                                                output_gate_derivative);
-
-  acti_func.run_prime_fn(next_cell_state, prev_cell_state_derivative,
-                         next_hidden_state_derivative);
-  prev_cell_state_derivative.multiply_i_strided(output_gate);
-  prev_cell_state_derivative.add_i(next_cell_state_derivative);
-
-  prev_cell_state_derivative.multiply_strided(input_gate,
-                                              memory_cell_derivative);
-  prev_cell_state_derivative.multiply_strided(memory_cell,
-                                              input_gate_derivative);
-
-  prev_cell_state_derivative.multiply_strided(prev_cell_state,
-                                              forget_gate_derivative);
-  prev_cell_state_derivative.multiply_i_strided(forget_gate);
-
-  recurrent_acti_func.run_prime_fn(output_gate, output_gate_derivative,
-                                   output_gate_derivative);
-  recurrent_acti_func.run_prime_fn(input_forget_gate,
-                                   input_forget_gate_derivative,
-                                   input_forget_gate_derivative);
-  acti_func.run_prime_fn(memory_cell, memory_cell_derivative,
-                         memory_cell_derivative);
-
-  if (!disable_bias) {
-    if (integrate_bias) {
-      ifgo_derivative.sum(0, djdbias_h, 1.0f, 1.0f);
-    } else {
-      ifgo_derivative.sum(0, djdbias_ih, 1.0f, 1.0f);
-      ifgo_derivative.sum(0, djdbias_hh, 1.0f, 1.0f);
-    }
-  }
-  input.dot(ifgo_derivative, djdweight_ih, true, false, 1.0f);
-  prev_hidden_state.dot(ifgo_derivative, djdweight_hh, true, false, 1.0f);
-  ifgo_derivative.dot(weight_hh, prev_hidden_state_derivative, false, true);
-}
-
-void LSTMCellCoreLayer::setBatch(RunLayerContext &context, unsigned int batch) {
-  context.updateTensor(wt_idx[LSTMCellCoreParams::ifgo], batch);
-}
-
 void lstmcell_forwarding(const unsigned int unit, const unsigned int batch_size,
                          const bool disable_bias, const bool integrate_bias,
                          ActiFunc &acti_func, ActiFunc &recurrent_acti_func,
index 4f02398..cdb33ed 100644
@@ -4,7 +4,7 @@
  *
  * @file   lstmcell_core.h
  * @date   25 November 2021
- * @brief  This is LSTMCellCore Layer Class of Neural Network
+ * @brief  These are lstm core functions.
  * @see           https://github.com/nnstreamer/nntrainer
  * @author hyeonseok lee <hs89.lee@samsung.com>
  * @bug    No known bugs except for NYI items
 #ifdef __cplusplus
 
 #include <acti_func.h>
-#include <common_properties.h>
-#include <layer_impl.h>
 
 namespace nntrainer {
 
-namespace init_lstm_context {
-void fillLayerInitContext(InitLayerContext &context,
-                          const InitLayerContext &core_context);
-void fillWeights(std::vector<Weight> &weights, const RunLayerContext &context,
-                 bool training, const std::vector<unsigned int> &wt_idx,
-                 const unsigned int max_timestep, const unsigned int timestep,
-                 bool test = false);
-const std::vector<Weight *> getWeights(std::vector<Weight> &weights);
-void fillInputs(std::vector<Var_Grad> &inputs, RunLayerContext &context,
-                bool training, const std::vector<unsigned int> &wt_idx,
-                const unsigned int max_timestep, const unsigned int timestep);
-const std::vector<Var_Grad *> getInputs(std::vector<Var_Grad> &inputs);
-void fillOutputs(std::vector<Var_Grad> &outputs, RunLayerContext &context,
-                 bool training, const std::vector<unsigned int> &wt_idx,
-                 const unsigned int max_timestep, const unsigned int timestep);
-const std::vector<Var_Grad *> getOutputs(std::vector<Var_Grad> &outputs);
-void fillTensors(std::vector<Var_Grad> &tensors, RunLayerContext &context,
-                 bool training, const std::vector<unsigned int> &wt_idx,
-                 const unsigned int max_timestep, const unsigned int timestep);
-const std::vector<Var_Grad *> getTensors(std::vector<Var_Grad> &tensors);
-} // namespace init_lstm_context
-
-/**
- * @class   LSTMCellCoreLayer
- * @brief   LSTMCellCoreLayer
- */
-class LSTMCellCoreLayer : public LayerImpl {
-public:
-  /**
-   * @brief     Constructor of LSTMCellLayer
-   */
-  LSTMCellCoreLayer();
-
-  /**
-   * @brief     Destructor of LSTMCellLayer
-   */
-  ~LSTMCellCoreLayer() = default;
-
-  /**
-   * @copydoc Layer::finalize(InitLayerContext &context)
-   */
-  void finalize(InitLayerContext &context) override;
-
-  /**
-   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
-   */
-  void forwarding(RunLayerContext &context, bool training) override;
-
-  /**
-   * @copydoc Layer::calcDerivative(RunLayerContext &context)
-   */
-  void calcDerivative(RunLayerContext &context) override;
-
-  /**
-   * @copydoc Layer::calcGradient(RunLayerContext &context)
-   */
-  void calcGradient(RunLayerContext &context) override;
-  /**
-   * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
-   */
-  void exportTo(Exporter &exporter, const ExportMethods &method) const override;
-
-  /**
-   * @copydoc Layer::getType()
-   */
-  const std::string getType() const override {
-    return LSTMCellCoreLayer::type;
-  };
-
-  /**
-   * @copydoc Layer::supportBackwarding()
-   */
-  bool supportBackwarding() const override { return true; }
-
-  /**
-   * @copydoc Layer::setProperty(const PropertyType type, const std::string
-   * &value)
-   */
-  void setProperty(const std::vector<std::string> &values) override;
-
-  /**
-   * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch)
-   */
-  void setBatch(RunLayerContext &context, unsigned int batch) override;
-
-  inline static const std::string type = "lstmcell_core";
-
-private:
-  static constexpr unsigned int NUM_GATE = 4;
-
-  /**
-   * Unit: number of output neurons
-   * HiddenStateActivation: activation type for hidden state. default is tanh
-   * RecurrentActivation: activation type for recurrent. default is sigmoid
-   * IntegrateBias: integrate bias_ih, bias_hh to bias_h
-   *
-   * */
-  std::tuple<props::Unit, props::HiddenStateActivation,
-             props::RecurrentActivation, props::IntegrateBias>
-    lstmcell_core_props;
-  std::array<unsigned int, 6> wt_idx; /**< indices of the weights */
-
-  /**
-   * @brief     activation function for h_t : default is tanh
-   */
-  ActiFunc acti_func;
-
-  /**
-   * @brief     activation function for recurrent : default is sigmoid
-   */
-  ActiFunc recurrent_acti_func;
-};
-
 void lstmcell_forwarding(const unsigned int unit, const unsigned int batch_size,
                          const bool disable_bias, const bool integrate_bias,
                          ActiFunc &acti_func, ActiFunc &recurrent_acti_func,
index 5c0c3d6..1f5e01a 100644
@@ -32,63 +32,21 @@ enum ZoneoutLSTMParams {
   hidden_state,
   cell_state,
   ifgo,
+  lstm_cell_state,
   hidden_state_zoneout_mask,
   cell_state_zoneout_mask,
 };
 
-unsigned int hidden_state_origin_idx = 0, cell_state_origin_idx = 0;
-
-const std::vector<unsigned int>
-getWeightIdx(std::array<unsigned int, 10> &wt_idx, const bool disable_bias,
-             const bool integrate_bias, const bool test) {
-  std::vector<unsigned int> ret;
-  ret.push_back(wt_idx[ZoneoutLSTMParams::weight_ih]);
-  ret.push_back(wt_idx[ZoneoutLSTMParams::weight_hh]);
-  if (!disable_bias) {
-    if (integrate_bias) {
-      ret.push_back(wt_idx[ZoneoutLSTMParams::bias_h]);
-    } else {
-      ret.push_back(wt_idx[ZoneoutLSTMParams::bias_ih]);
-      ret.push_back(wt_idx[ZoneoutLSTMParams::bias_hh]);
-    }
-  }
-  if (test) {
-    ret.push_back(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
-    ret.push_back(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
-  }
-  return ret;
-}
-
-const std::vector<unsigned int>
-getInputIdx(std::array<unsigned int, 10> &wt_idx) {
-  std::vector<unsigned int> ret(3);
-  ret[0] = SINGLE_INOUT_IDX;
-  ret[1] = wt_idx[ZoneoutLSTMParams::hidden_state];
-  ret[2] = wt_idx[ZoneoutLSTMParams::cell_state];
-  return ret;
-}
-
-const std::vector<unsigned int>
-getOutputIdx(std::array<unsigned int, 10> &wt_idx) {
-  std::vector<unsigned int> ret(3);
-  ret[0] = SINGLE_INOUT_IDX;
-  ret[1] = hidden_state_origin_idx;
-  ret[2] = cell_state_origin_idx;
-  return ret;
-}
-
-const std::vector<unsigned int>
-getTensorIdx(std::array<unsigned int, 10> &wt_idx) {
-  std::vector<unsigned int> ret(1);
-  ret[0] = wt_idx[ZoneoutLSTMParams::ifgo];
-  return ret;
-}
-
 ZoneoutLSTMCellLayer::ZoneoutLSTMCellLayer() :
   LayerImpl(),
-  zoneout_lstmcell_props(props::Unit(), HiddenStateZoneOutRate(),
-                         CellStateZoneOutRate(), props::IntegrateBias(), Test(),
-                         props::MaxTimestep(), props::Timestep()),
+  zoneout_lstmcell_props(
+    props::Unit(), props::IntegrateBias(),
+    props::HiddenStateActivation() = ActivationType::ACT_TANH,
+    props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
+    HiddenStateZoneOutRate(), CellStateZoneOutRate(), Test(),
+    props::MaxTimestep(), props::Timestep()),
+  acti_func(ActivationType::ACT_NONE, true),
+  recurrent_acti_func(ActivationType::ACT_NONE, true),
   epsilon(1e-3) {
   wt_idx.fill(std::numeric_limits<unsigned>::max());
 }
@@ -112,7 +70,6 @@ bool ZoneoutLSTMCellLayer::CellStateZoneOutRate::isValid(
 }
 
 void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
-#if !ENABLE_SHARING_WT_IDX
   const Tensor::Initializer weight_initializer =
     std::get<props::WeightInitializer>(*layer_impl_props).get();
   const Tensor::Initializer bias_initializer =
@@ -123,7 +80,6 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
     std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
   const bool disable_bias =
     std::get<props::DisableBias>(*layer_impl_props).get();
-#endif
 
   NNTR_THROW_IF(std::get<props::Unit>(zoneout_lstmcell_props).empty(),
                 std::invalid_argument)
@@ -131,6 +87,10 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
   const unsigned int unit = std::get<props::Unit>(zoneout_lstmcell_props).get();
   const bool integrate_bias =
     std::get<props::IntegrateBias>(zoneout_lstmcell_props).get();
+  const ActivationType hidden_state_activation_type =
+    std::get<props::HiddenStateActivation>(zoneout_lstmcell_props).get();
+  const ActivationType recurrent_activation_type =
+    std::get<props::RecurrentActivation>(zoneout_lstmcell_props).get();
   const bool test = std::get<Test>(zoneout_lstmcell_props).get();
   const unsigned int max_timestep =
     std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
@@ -146,20 +106,17 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
 
   // input_dim = [ batch_size, 1, 1, feature_size ]
   const TensorDim &input_dim = context.getInputDimensions()[0];
-  if (input_dim.height() != 1 || input_dim.channel() != 1)
+  if (input_dim.channel() != 1 || input_dim.height() != 1)
     throw std::invalid_argument("Input must be single time dimension for "
                                 "ZoneoutLSTMCell (shape should be "
                                 "[batch_size, 1, 1, feature_size])");
   const unsigned int batch_size = input_dim.batch();
-#if !ENABLE_SHARING_WT_IDX
   const unsigned int feature_size = input_dim.width();
-#endif
 
   // output_dim = [ batch_size, 1, 1, unit ]
   const TensorDim output_dim(batch_size, 1, 1, unit);
   context.setOutputDimensions({output_dim});
 
-#if !ENABLE_SHARING_WT_IDX
   // weight_initializer can be set seperately. weight_ih initializer,
   // weight_hh initializer kernel initializer & recurrent_initializer in keras
   // for now, it is set same way.
@@ -199,7 +156,6 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
                               WeightRegularizer::NONE, 1.0f, "bias_hh", true);
     }
   }
-#endif
 
   /**
    * TODO: hidden_state is only used from the previous timestep. Once it is
@@ -216,24 +172,17 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
     cell_state_dim, "cell_state", Tensor::Initializer::NONE, true,
     TensorLifespan::ITERATION_LIFESPAN, false);
 
-  hidden_state_origin_idx = context.requestTensor(
-    hidden_state_dim, "hidden_state_origin", Tensor::Initializer::NONE, true,
-    TensorLifespan::ITERATION_LIFESPAN, false);
-  cell_state_origin_idx = context.requestTensor(
-    cell_state_dim, "cell_state_origin", Tensor::Initializer::NONE, true,
-    TensorLifespan::ITERATION_LIFESPAN, false);
-
-#if !ENABLE_SHARING_WT_IDX
-  /**
-   * TODO: make this independent of time dimension once recurrent realizer
-   * supports requesting tensors which are not always shared
-   */
-  /** ifgo_dim = [ max_timestep * batch_size, 1, 1, NUM_GATE * unit ] */
-  const TensorDim ifgo_dim(max_timestep * batch_size, 1, 1, NUM_GATE * unit);
+  /** ifgo_dim = [ batch_size, 1, 1, NUM_GATE * unit ] */
+  const TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit);
   wt_idx[ZoneoutLSTMParams::ifgo] =
     context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true,
-                          TensorLifespan::ITERATION_LIFESPAN, false);
-#endif
+                          TensorLifespan::ITERATION_LIFESPAN);
+
+  /** lstm_cell_state_dim = [ batch_size, 1, 1, unit ] */
+  const TensorDim lstm_cell_state_dim(batch_size, 1, 1, unit);
+  wt_idx[ZoneoutLSTMParams::lstm_cell_state] = context.requestTensor(
+    lstm_cell_state_dim, "lstm_cell_state", Tensor::Initializer::NONE, true,
+    TensorLifespan::ITERATION_LIFESPAN);
 
   // hidden_state_zoneout_mask_dim = [ max_timestep * batch_size, 1, 1, unit ]
   const TensorDim hidden_state_zoneout_mask_dim(max_timestep * batch_size, 1, 1,
@@ -262,55 +211,20 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
       Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN);
   }
 
-  TensorDim hidden_state_t_dim({batch_size, 1, 1, unit});
-  TensorDim cell_state_t_dim({batch_size, 1, 1, unit});
-  InitLayerContext core_context(
-    {input_dim, hidden_state_t_dim, cell_state_t_dim}, 3,
-    context.executeInPlace(), context.getName());
-  lstmcellcorelayer.finalize(core_context);
-  init_lstm_context::fillLayerInitContext(context, core_context);
+  acti_func.setActiFunc(hidden_state_activation_type);
+  recurrent_acti_func.setActiFunc(recurrent_activation_type);
 }
 
 void ZoneoutLSTMCellLayer::setProperty(const std::vector<std::string> &values) {
-  std::vector<std::string> remain_props =
+  const std::vector<std::string> &remain_props =
     loadProperties(values, zoneout_lstmcell_props);
-
-  // Note: In current implementation the lstmcellcorelayer also has
-  // a properties related to weight. But it is not exported or used anywhere.
-  lstmcellcorelayer.setProperty(remain_props);
-  if (!std::get<props::Unit>(zoneout_lstmcell_props).empty()) {
-    lstmcellcorelayer.setProperty(
-      {"unit=" + to_string(std::get<props::Unit>(zoneout_lstmcell_props))});
-  }
-  lstmcellcorelayer.setProperty(
-    {"integrate_bias=" +
-     to_string(std::get<props::IntegrateBias>(zoneout_lstmcell_props))});
-
-#if !ENABLE_SHARING_WT_IDX
-  // To remove lstmcell core layer's properties
-  std::tuple<props::HiddenStateActivation, props::RecurrentActivation>
-    lstmcell_core_props;
-  std::vector<std::string> impl_props =
-    loadProperties(remain_props, lstmcell_core_props);
-
-  LayerImpl::setProperty(impl_props);
-#endif
+  LayerImpl::setProperty(remain_props);
 }
 
 void ZoneoutLSTMCellLayer::exportTo(Exporter &exporter,
                                     const ExportMethods &method) const {
-#if !ENABLE_SHARING_WT_IDX
   LayerImpl::exportTo(exporter, method);
-#endif
-  exporter.saveResult(
-    std::forward_as_tuple(
-      std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props),
-      std::get<CellStateZoneOutRate>(zoneout_lstmcell_props),
-      std::get<Test>(zoneout_lstmcell_props),
-      std::get<props::MaxTimestep>(zoneout_lstmcell_props),
-      std::get<props::Timestep>(zoneout_lstmcell_props)),
-    method, this);
-  lstmcellcorelayer.exportTo(exporter, method);
+  exporter.saveResult(zoneout_lstmcell_props, method, this);
 }
 
 void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
@@ -318,169 +232,128 @@ void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
     std::get<props::DisableBias>(*layer_impl_props).get();
 
   const unsigned int unit = std::get<props::Unit>(zoneout_lstmcell_props).get();
+  const bool integrate_bias =
+    std::get<props::IntegrateBias>(zoneout_lstmcell_props).get();
   const float hidden_state_zoneout_rate =
     std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props).get();
   const float cell_state_zoneout_rate =
     std::get<CellStateZoneOutRate>(zoneout_lstmcell_props).get();
-  const bool integrate_bias =
-    std::get<props::IntegrateBias>(zoneout_lstmcell_props).get();
   const bool test = std::get<Test>(zoneout_lstmcell_props).get();
   const unsigned int max_timestep =
     std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
   const unsigned int timestep =
     std::get<props::Timestep>(zoneout_lstmcell_props).get();
 
-  const unsigned int batch_size =
-    context.getInput(SINGLE_INOUT_IDX).getDim().batch();
-
+  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
 
-  Tensor &hidden_state =
-    context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state]);
-  hidden_state.reshape({max_timestep, 1, batch_size, unit});
+  const unsigned int batch_size = input.getDim().batch();
+
+  const Tensor &weight_ih =
+    context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]);
+  const Tensor &weight_hh =
+    context.getWeight(wt_idx[ZoneoutLSTMParams::weight_hh]);
+  Tensor empty;
+  Tensor &bias_h = !disable_bias && integrate_bias
+                     ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_h])
+                     : empty;
+  Tensor &bias_ih = !disable_bias && !integrate_bias
+                      ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_ih])
+                      : empty;
+  Tensor &bias_hh = !disable_bias && !integrate_bias
+                      ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_hh])
+                      : empty;
+
+  Tensor &hs = context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state]);
+  hs.reshape({max_timestep, 1, batch_size, unit});
   Tensor prev_hidden_state;
   if (!timestep) {
     prev_hidden_state = Tensor(batch_size, 1, 1, unit);
     prev_hidden_state.setZero();
   } else {
-    prev_hidden_state = hidden_state.getBatchSlice(timestep - 1, 1);
+    prev_hidden_state = hs.getBatchSlice(timestep - 1, 1);
     prev_hidden_state.reshape({batch_size, 1, 1, unit});
   }
-  Tensor next_hidden_state = hidden_state.getBatchSlice(timestep, 1);
-  next_hidden_state.reshape({batch_size, 1, 1, unit});
+  Tensor hidden_state = hs.getBatchSlice(timestep, 1);
+  hidden_state.reshape({batch_size, 1, 1, unit});
 
-  Tensor &cell_state = context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state]);
-  cell_state.reshape({max_timestep, 1, batch_size, unit});
+  Tensor &cs = context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state]);
+  cs.reshape({max_timestep, 1, batch_size, unit});
   Tensor prev_cell_state;
   if (!timestep) {
     prev_cell_state = Tensor(batch_size, 1, 1, unit);
     prev_cell_state.setZero();
   } else {
-    prev_cell_state = cell_state.getBatchSlice(timestep - 1, 1);
+    prev_cell_state = cs.getBatchSlice(timestep - 1, 1);
     prev_cell_state.reshape({batch_size, 1, 1, unit});
   }
-  Tensor next_cell_state = cell_state.getBatchSlice(timestep, 1);
-  next_cell_state.reshape({batch_size, 1, 1, unit});
+  Tensor cell_state = cs.getBatchSlice(timestep, 1);
+  cell_state.reshape({batch_size, 1, 1, unit});
 
-  if (!timestep) {
-    hidden_state.setZero();
-    cell_state.setZero();
-  }
+  Tensor &ifgo = context.getTensor(wt_idx[ZoneoutLSTMParams::ifgo]);
 
-  init_lstm_context::fillWeights(
-    weights, context, training,
-    getWeightIdx(wt_idx, disable_bias, integrate_bias, test), max_timestep,
-    timestep, test);
-  init_lstm_context::fillInputs(inputs, context, training, getInputIdx(wt_idx),
-                                max_timestep, timestep);
-  init_lstm_context::fillOutputs(outputs, context, training,
-                                 getOutputIdx(wt_idx), max_timestep, timestep);
-  init_lstm_context::fillTensors(tensors, context, training,
-                                 getTensorIdx(wt_idx), max_timestep, timestep);
-  RunLayerContext core_context(context.getName(), context.getTrainable(),
-                               context.getLoss(), context.executeInPlace(),
-                               init_lstm_context::getWeights(weights),
-                               init_lstm_context::getInputs(inputs),
-                               init_lstm_context::getOutputs(outputs),
-                               init_lstm_context::getTensors(tensors));
-  lstmcellcorelayer.forwarding(core_context, training);
+  Tensor &lstm_cell_state =
+    context.getTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state]);
+
+  lstmcell_forwarding(unit, batch_size, disable_bias, integrate_bias, acti_func,
+                      recurrent_acti_func, input, prev_hidden_state,
+                      prev_cell_state, hidden_state, lstm_cell_state, weight_ih,
+                      weight_hh, bias_h, bias_ih, bias_hh, ifgo);
 
   if (training) {
-    Tensor &hidden_state_zoneout_mask =
+    Tensor &hs_zoneout_mask =
       test ? context.getWeight(
                wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
            : context.getTensor(
                wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
-    hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
-    Tensor next_hidden_state_zoneout_mask =
-      hidden_state_zoneout_mask.getBatchSlice(timestep, 1);
-    next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+    hs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+    Tensor hidden_state_zoneout_mask =
+      hs_zoneout_mask.getBatchSlice(timestep, 1);
+    hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
     Tensor prev_hidden_state_zoneout_mask;
     if (!test) {
       prev_hidden_state_zoneout_mask =
-        next_hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
+        hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
     } else {
-      next_hidden_state_zoneout_mask.multiply(-1.0f,
-                                              prev_hidden_state_zoneout_mask);
+      hidden_state_zoneout_mask.multiply(-1.0f, prev_hidden_state_zoneout_mask);
       prev_hidden_state_zoneout_mask.add_i(1.0f);
     }
 
-    Tensor &hidden_state_origin = context.getTensor(hidden_state_origin_idx);
-    hidden_state_origin.reshape({max_timestep, 1, batch_size, unit});
-    Tensor next_hidden_state_origin =
-      hidden_state_origin.getBatchSlice(timestep, 1);
-    next_hidden_state_origin.reshape({batch_size, 1, 1, unit});
+    hidden_state.multiply_i(hidden_state_zoneout_mask);
+    prev_hidden_state.multiply(prev_hidden_state_zoneout_mask, hidden_state,
+                               1.0f);
 
-    next_hidden_state_origin.multiply(next_hidden_state_zoneout_mask,
-                                      next_hidden_state);
-    prev_hidden_state.multiply(prev_hidden_state_zoneout_mask,
-                               next_hidden_state, 1.0f);
-  }
-
-  if (training) {
-    Tensor &cell_state_zoneout_mask =
+    Tensor &cs_zoneout_mask =
       test
         ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
         : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
-    cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
-    Tensor next_cell_state_zoneout_mask =
-      cell_state_zoneout_mask.getBatchSlice(timestep, 1);
-    next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+    cs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+    Tensor cell_state_zoneout_mask = cs_zoneout_mask.getBatchSlice(timestep, 1);
+    cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
     Tensor prev_cell_state_zoneout_mask;
     if (!test) {
       prev_cell_state_zoneout_mask =
-        next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
+        cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
     } else {
-      next_cell_state_zoneout_mask.multiply(-1.0f,
-                                            prev_cell_state_zoneout_mask);
+      cell_state_zoneout_mask.multiply(-1.0f, prev_cell_state_zoneout_mask);
       prev_cell_state_zoneout_mask.add_i(1.0f);
     }
 
-    Tensor &cell_state_origin = context.getTensor(cell_state_origin_idx);
-    cell_state_origin.reshape({max_timestep, 1, batch_size, unit});
-    Tensor next_cell_state_origin =
-      cell_state_origin.getBatchSlice(timestep, 1);
-    next_cell_state_origin.reshape({batch_size, 1, 1, unit});
-
-    next_cell_state_origin.multiply(next_cell_state_zoneout_mask,
-                                    next_cell_state);
-    prev_cell_state.multiply(prev_cell_state_zoneout_mask, next_cell_state,
-                             1.0f);
+    lstm_cell_state.multiply(cell_state_zoneout_mask, cell_state);
+    prev_cell_state.multiply(prev_cell_state_zoneout_mask, cell_state, 1.0f);
   }
   // Todo: zoneout at inference
 
-  output.copyData(next_hidden_state);
+  output.copyData(hidden_state);
 }
 
 void ZoneoutLSTMCellLayer::calcDerivative(RunLayerContext &context) {
-  const bool disable_bias =
-    std::get<props::DisableBias>(*layer_impl_props).get();
+  Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]);
+  const Tensor &weight_ih =
+    context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]);
+  Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
 
-  const bool integrate_bias =
-    std::get<props::IntegrateBias>(zoneout_lstmcell_props).get();
-  const bool test = std::get<Test>(zoneout_lstmcell_props).get();
-  const unsigned int max_timestep =
-    std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
-  const unsigned int timestep =
-    std::get<props::Timestep>(zoneout_lstmcell_props).get();
-
-  init_lstm_context::fillWeights(
-    weights, context, true,
-    getWeightIdx(wt_idx, disable_bias, integrate_bias, test), max_timestep,
-    timestep, test);
-  init_lstm_context::fillInputs(inputs, context, true, getInputIdx(wt_idx),
-                                max_timestep, timestep);
-  init_lstm_context::fillOutputs(outputs, context, true, getOutputIdx(wt_idx),
-                                 max_timestep, timestep);
-  init_lstm_context::fillTensors(tensors, context, true, getTensorIdx(wt_idx),
-                                 max_timestep, timestep);
-  RunLayerContext core_context(context.getName(), context.getTrainable(),
-                               context.getLoss(), context.executeInPlace(),
-                               init_lstm_context::getWeights(weights),
-                               init_lstm_context::getInputs(inputs),
-                               init_lstm_context::getOutputs(outputs),
-                               init_lstm_context::getTensors(tensors));
-  lstmcellcorelayer.calcDerivative(core_context);
+  lstmcell_calcDerivative(d_ifgo, weight_ih, outgoing_derivative);
 }
 
 void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
@@ -488,176 +361,185 @@ void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
     std::get<props::DisableBias>(*layer_impl_props).get();
 
   const unsigned int unit = std::get<props::Unit>(zoneout_lstmcell_props).get();
-  const float hidden_state_zoneout_rate =
-    std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props);
-  const float cell_state_zoneout_rate =
-    std::get<CellStateZoneOutRate>(zoneout_lstmcell_props);
   const bool integrate_bias =
     std::get<props::IntegrateBias>(zoneout_lstmcell_props).get();
-  const bool test = std::get<Test>(zoneout_lstmcell_props);
+  const float hidden_state_zoneout_rate =
+    std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props).get();
+  const float cell_state_zoneout_rate =
+    std::get<CellStateZoneOutRate>(zoneout_lstmcell_props).get();
+  const bool test = std::get<Test>(zoneout_lstmcell_props).get();
   const unsigned int max_timestep =
     std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
   const unsigned int timestep =
     std::get<props::Timestep>(zoneout_lstmcell_props).get();
 
-  unsigned int batch_size = context.getInput(SINGLE_INOUT_IDX).getDim().batch();
-
+  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
   const Tensor &incoming_derivative =
     context.getIncomingDerivative(SINGLE_INOUT_IDX);
 
-  Tensor &hidden_state_derivative =
-    context.getTensorGrad(wt_idx[ZoneoutLSTMParams::hidden_state]);
-  hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit});
-  Tensor next_hidden_state_derivative =
-    hidden_state_derivative.getBatchSlice(timestep, 1);
-  next_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
+  unsigned int batch_size = input.getDim().batch();
+
+  Tensor &d_weight_ih =
+    context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_ih]);
+  const Tensor &weight_hh =
+    context.getWeight(wt_idx[ZoneoutLSTMParams::weight_hh]);
+  Tensor &d_weight_hh =
+    context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_hh]);
+  Tensor empty;
+  Tensor &d_bias_h =
+    !disable_bias && integrate_bias
+      ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_h])
+      : empty;
+  Tensor &d_bias_ih =
+    !disable_bias && !integrate_bias
+      ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_ih])
+      : empty;
+  Tensor &d_bias_hh =
+    !disable_bias && !integrate_bias
+      ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_hh])
+      : empty;
+
+  Tensor &hs = context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state]);
+  hs.reshape({max_timestep, 1, batch_size, unit});
+  Tensor prev_hidden_state;
+  if (!timestep) {
+    prev_hidden_state = Tensor(batch_size, 1, 1, unit);
+    prev_hidden_state.setZero();
+  } else {
+    prev_hidden_state = hs.getBatchSlice(timestep - 1, 1);
+    prev_hidden_state.reshape({batch_size, 1, 1, unit});
+  }
+
+  Tensor &d_hs = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::hidden_state]);
+  d_hs.reshape({max_timestep, 1, batch_size, unit});
+  Tensor d_prev_hidden_state;
+  if (!timestep) {
+    d_prev_hidden_state = Tensor(batch_size, 1, 1, unit);
+    d_prev_hidden_state.setZero();
+  } else {
+    d_prev_hidden_state = d_hs.getBatchSlice(timestep - 1, 1);
+    d_prev_hidden_state.reshape({batch_size, 1, 1, unit});
+  }
+  Tensor d_hidden_state = d_hs.getBatchSlice(timestep, 1);
+  d_hidden_state.reshape({batch_size, 1, 1, unit});
+
+  Tensor &cs = context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state]);
+  cs.reshape({max_timestep, 1, batch_size, unit});
+  Tensor prev_cell_state;
+  if (!timestep) {
+    prev_cell_state = Tensor(batch_size, 1, 1, unit);
+    prev_cell_state.setZero();
+  } else {
+    prev_cell_state = cs.getBatchSlice(timestep - 1, 1);
+    prev_cell_state.reshape({batch_size, 1, 1, unit});
+  }
+  Tensor cell_state = cs.getBatchSlice(timestep, 1);
+  cell_state.reshape({batch_size, 1, 1, unit});
 
-  Tensor &cell_state_derivative =
-    context.getTensorGrad(wt_idx[ZoneoutLSTMParams::cell_state]);
-  cell_state_derivative.reshape({max_timestep, 1, batch_size, unit});
-  Tensor next_cell_state_derivative =
-    cell_state_derivative.getBatchSlice(timestep, 1);
-  next_cell_state_derivative.reshape({batch_size, 1, 1, unit});
+  Tensor &d_cs = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::cell_state]);
+  d_cs.reshape({max_timestep, 1, batch_size, unit});
+  Tensor d_prev_cell_state;
+  if (!timestep) {
+    d_prev_cell_state = Tensor(batch_size, 1, 1, unit);
+    d_prev_cell_state.setZero();
+  } else {
+    d_prev_cell_state = d_cs.getBatchSlice(timestep - 1, 1);
+    d_prev_cell_state.reshape({batch_size, 1, 1, unit});
+  }
+  Tensor d_cell_state = d_cs.getBatchSlice(timestep, 1);
+  d_cell_state.reshape({batch_size, 1, 1, unit});
+
+  Tensor &ifgo = context.getTensor(wt_idx[ZoneoutLSTMParams::ifgo]);
+  Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]);
+
+  const Tensor &lstm_cell_state =
+    context.getTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state]);
+  Tensor &d_lstm_cell_state =
+    context.getTensorGrad(wt_idx[ZoneoutLSTMParams::lstm_cell_state]);
 
   if (timestep + 1 == max_timestep) {
-    Tensor &djdweight_ih =
-      context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_ih]);
-    Tensor &djdweight_hh =
-      context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_hh]);
-    djdweight_ih.setZero();
-    djdweight_hh.setZero();
+    d_weight_ih.setZero();
+    d_weight_hh.setZero();
     if (!disable_bias) {
       if (integrate_bias) {
-        Tensor &djdbias_h =
-          context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_h]);
-        djdbias_h.setZero();
+        d_bias_h.setZero();
       } else {
-        Tensor &djdbias_ih =
-          context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_ih]);
-        djdbias_ih.setZero();
-        Tensor &djdbias_hh =
-          context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_hh]);
-        djdbias_hh.setZero();
+        d_bias_ih.setZero();
+        d_bias_hh.setZero();
       }
     }
-
-    hidden_state_derivative.setZero();
-    cell_state_derivative.setZero();
+    d_hidden_state.setZero();
+    d_cell_state.setZero();
   }
 
-  next_hidden_state_derivative.add_i(incoming_derivative);
+  d_hidden_state.add_i(incoming_derivative);
 
-  Tensor prev_hidden_state_derivative;
-  Tensor prev_cell_state_derivative;
-  Tensor prev_hidden_state_derivative_residual;
-  Tensor prev_cell_state_derivative_residual;
+  Tensor d_prev_hidden_state_residual;
 
-  Tensor &hidden_state_zoneout_mask =
+  Tensor &hs_zoneout_mask =
     test
       ? context.getWeight(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
       : context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
-  hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
-  Tensor next_hidden_state_zoneout_mask =
-    hidden_state_zoneout_mask.getBatchSlice(timestep, 1);
-  next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+  hs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+  Tensor hidden_state_zoneout_mask = hs_zoneout_mask.getBatchSlice(timestep, 1);
+  hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
   Tensor prev_hidden_state_zoneout_mask;
   if (!test) {
     prev_hidden_state_zoneout_mask =
-      next_hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
+      hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
   } else {
-    next_hidden_state_zoneout_mask.multiply(-1.0f,
-                                            prev_hidden_state_zoneout_mask);
+    hidden_state_zoneout_mask.multiply(-1.0f, prev_hidden_state_zoneout_mask);
     prev_hidden_state_zoneout_mask.add_i(1.0f);
   }
 
-  if (timestep) {
-    prev_hidden_state_derivative =
-      hidden_state_derivative.getBatchSlice(timestep - 1, 1);
-    prev_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
-    next_hidden_state_derivative.multiply(
-      prev_hidden_state_zoneout_mask, prev_hidden_state_derivative_residual);
-  }
-
-  Tensor &hidden_state_origin_derivative =
-    context.getTensorGrad(hidden_state_origin_idx);
-  hidden_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit});
-  Tensor next_hidden_state_origin_derivative =
-    hidden_state_origin_derivative.getBatchSlice(timestep, 1);
-  next_hidden_state_origin_derivative.reshape({batch_size, 1, 1, unit});
+  d_hidden_state.multiply(prev_hidden_state_zoneout_mask,
+                          d_prev_hidden_state_residual);
+  d_hidden_state.multiply_i(hidden_state_zoneout_mask);
 
-  next_hidden_state_derivative.multiply(next_hidden_state_zoneout_mask,
-                                        next_hidden_state_origin_derivative);
+  Tensor d_prev_cell_state_residual;
 
-  Tensor &cell_state_zoneout_mask =
+  Tensor &cs_zoneout_mask =
     test
       ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
       : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
-  cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
-  Tensor next_cell_state_zoneout_mask =
-    cell_state_zoneout_mask.getBatchSlice(timestep, 1);
-  next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+  cs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+  Tensor cell_state_zoneout_mask = cs_zoneout_mask.getBatchSlice(timestep, 1);
+  cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
   Tensor prev_cell_state_zoneout_mask;
   if (!test) {
     prev_cell_state_zoneout_mask =
-      next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
+      cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
   } else {
-    next_cell_state_zoneout_mask.multiply(-1.0f, prev_cell_state_zoneout_mask);
+    cell_state_zoneout_mask.multiply(-1.0f, prev_cell_state_zoneout_mask);
     prev_cell_state_zoneout_mask.add_i(1.0f);
   }
 
-  if (timestep) {
-    prev_cell_state_derivative =
-      cell_state_derivative.getBatchSlice(timestep - 1, 1);
-    prev_cell_state_derivative.reshape({batch_size, 1, 1, unit});
-    next_cell_state_derivative.multiply(prev_cell_state_zoneout_mask,
-                                        prev_cell_state_derivative_residual);
-  }
+  d_cell_state.multiply(prev_cell_state_zoneout_mask,
+                        d_prev_cell_state_residual);
+  d_cell_state.multiply(cell_state_zoneout_mask, d_lstm_cell_state);
 
-  Tensor &cell_state_origin_derivative =
-    context.getTensorGrad(cell_state_origin_idx);
-  cell_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit});
-  Tensor next_cell_state_origin_derivative =
-    cell_state_origin_derivative.getBatchSlice(timestep, 1);
-  next_cell_state_origin_derivative.reshape({batch_size, 1, 1, unit});
-
-  next_cell_state_derivative.multiply(next_cell_state_zoneout_mask,
-                                      next_cell_state_origin_derivative);
-
-  init_lstm_context::fillWeights(
-    weights, context, true,
-    getWeightIdx(wt_idx, disable_bias, integrate_bias, test), max_timestep,
-    timestep, test);
-  init_lstm_context::fillInputs(inputs, context, true, getInputIdx(wt_idx),
-                                max_timestep, timestep);
-  init_lstm_context::fillOutputs(outputs, context, true, getOutputIdx(wt_idx),
-                                 max_timestep, timestep);
-  init_lstm_context::fillTensors(tensors, context, true, getTensorIdx(wt_idx),
-                                 max_timestep, timestep);
-  RunLayerContext core_context(context.getName(), context.getTrainable(),
-                               context.getLoss(), context.executeInPlace(),
-                               init_lstm_context::getWeights(weights),
-                               init_lstm_context::getInputs(inputs),
-                               init_lstm_context::getOutputs(outputs),
-                               init_lstm_context::getTensors(tensors));
-  lstmcellcorelayer.calcGradient(core_context);
-
-  if (timestep) {
-    prev_hidden_state_derivative.add_i(prev_hidden_state_derivative_residual);
-    prev_cell_state_derivative.add_i(prev_cell_state_derivative_residual);
-  }
+  lstmcell_calcGradient(unit, batch_size, disable_bias, integrate_bias,
+                        acti_func, recurrent_acti_func, input,
+                        prev_hidden_state, d_prev_hidden_state, prev_cell_state,
+                        d_prev_cell_state, d_hidden_state, lstm_cell_state,
+                        d_lstm_cell_state, d_weight_ih, weight_hh, d_weight_hh,
+                        d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo);
+
+  d_prev_hidden_state.add_i(d_prev_hidden_state_residual);
+  d_prev_cell_state.add_i(d_prev_cell_state_residual);
 }
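
(Editorial note, not part of the patch.) With the nested core-layer context removed, calcGradient above splits the incoming state derivatives along the zoneout masks itself. Assuming the usual zoneout combination in the forward pass (the new LSTM state weighted by the mask, the previous state by its complement), the split reads as follows, in comment form mirroring the tensor names above:

    // sketch only: m_h = hidden_state_zoneout_mask, m_c = cell_state_zoneout_mask
    // forward (not shown in this hunk):
    //   h_t = m_h .* lstm_h_t + (1 - m_h) .* h_{t-1}
    //   c_t = m_c .* lstm_c_t + (1 - m_c) .* c_{t-1}
    // backward, as computed above:
    //   d_lstm_h_t  = m_h .* d_h_t          // masked share, consumed by lstmcell_calcGradient
    //   d_h_{t-1}  += (1 - m_h) .* d_h_t    // d_prev_hidden_state_residual, added afterwards
    //   d_lstm_c_t  = m_c .* d_c_t          // d_lstm_cell_state
    //   d_c_{t-1}  += (1 - m_c) .* d_c_t    // d_prev_cell_state_residual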
 
 void ZoneoutLSTMCellLayer::setBatch(RunLayerContext &context,
                                     unsigned int batch) {
   const unsigned int max_timestep =
     std::get<props::MaxTimestep>(zoneout_lstmcell_props);
+
   context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state],
                        max_timestep * batch);
   context.updateTensor(wt_idx[ZoneoutLSTMParams::cell_state],
                        max_timestep * batch);
-  context.updateTensor(hidden_state_origin_idx, max_timestep * batch);
-  context.updateTensor(cell_state_origin_idx, max_timestep * batch);
-  context.updateTensor(wt_idx[ZoneoutLSTMParams::ifgo], max_timestep * batch);
+  context.updateTensor(wt_idx[ZoneoutLSTMParams::ifgo], batch);
+  context.updateTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state], batch);
 
   context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask],
                        max_timestep * batch);
index 3895b83..515f55e 100644 (file)
@@ -169,34 +169,38 @@ public:
 private:
   static constexpr unsigned int NUM_GATE = 4;
 
-  LSTMCellCoreLayer lstmcellcorelayer;
-
   /**
    * Unit: number of output neurons
+   * IntegrateBias: integrate bias_ih, bias_hh to bias_h
+   * HiddenStateActivation: activation for the hidden state (default: tanh)
+   * RecurrentActivation: activation for the recurrent gates (default: sigmoid)
    * HiddenStateZoneOutRate: zoneout rate for hidden_state
    * CellStateZoneOutRate: zoneout rate for cell_state
-   * IntegrateBias: integrate bias_ih, bias_hh to bias_h
    * Test: property for test mode
    * MaxTimestep: maximum timestep for zoneout lstmcell
    * TimeStep: timestep for which lstm should operate
    *
    * */
-  std::tuple<props::Unit, HiddenStateZoneOutRate, CellStateZoneOutRate,
-             props::IntegrateBias, Test, props::MaxTimestep, props::Timestep>
+  std::tuple<props::Unit, props::IntegrateBias, props::HiddenStateActivation,
+             props::RecurrentActivation, HiddenStateZoneOutRate,
+             CellStateZoneOutRate, Test, props::MaxTimestep, props::Timestep>
     zoneout_lstmcell_props;
-  std::array<unsigned int, 10> wt_idx; /**< indices of the weights */
+  std::array<unsigned int, 11> wt_idx; /**< indices of the weights */
+
+  /**
+   * @brief     activation function for h_t : default is tanh
+   */
+  ActiFunc acti_func;
+
+  /**
+   * @brief     activation function for recurrent : default is sigmoid
+   */
+  ActiFunc recurrent_acti_func;
 
   /**
    * @brief     Protect overflow
    */
   float epsilon;
-
-  // These weights, inputs, outputs, tensors are all for the lstm_core
-  // Todo: remove this
-  std::vector<Weight> weights;
-  std::vector<Var_Grad> inputs;
-  std::vector<Var_Grad> outputs;
-  std::vector<Var_Grad> tensors;
 };
 } // namespace nntrainer
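
(Editorial note, not part of the patch.) The widened property tuple above maps onto the keys a model description would set on this layer. A hypothetical configuration sketch follows, assuming the usual ml::train::createLayer factory and snake_case keys derived from the property names in the header; max_timestep and timestep are normally filled in by the recurrent realizer when the cell is unrolled, so they are omitted here:

    // sketch only: key names are assumed from the property classes listed above
    #include <layer.h>

    auto cell = ml::train::createLayer(
      "zoneout_lstmcell",
      {"name=zlstm0", "unit=32",
       "hidden_state_zoneout_rate=0.05", "cell_state_zoneout_rate=0.1",
       "integrate_bias=false",
       "hidden_state_activation=tanh", "recurrent_activation=sigmoid"});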