[lstmcell] support multi in/out
authorhyeonseok lee <hs89.lee@samsung.com>
Sat, 18 Dec 2021 20:33:03 +0000 (05:33 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 30 Dec 2021 10:05:55 +0000 (19:05 +0900)
 - Refactoring lstmcell layer to support multi in/out (3 input / 2 output)
 - Regenerate lstmcell testcase for multi in/out

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/compiler/recurrent_realizer.cpp
nntrainer/layers/lstmcell.cpp
nntrainer/layers/lstmcell.h
packaging/unittest_layers_v2.tar.gz
packaging/unittest_models_v2.tar.gz
test/input_gen/genLayerTests.py
test/input_gen/genModelsRecurrent_v2.py
test/input_gen/transLayer.py
test/unittest/layers/unittest_layers_lstmcell.cpp
test/unittest/models/unittest_models_recurrent.cpp

index d750100..113d999 100644 (file)
@@ -22,7 +22,6 @@
 #include <input_layer.h>
 #include <layer_node.h>
 #include <lstm.h>
-#include <lstmcell.h>
 #include <nntrainer_error.h>
 #include <node_exporter.h>
 #include <recurrent_realizer.h>
@@ -178,7 +177,6 @@ static void propagateTimestep(LayerNode *node, unsigned int time_step,
   auto is_recurrent_type = [](LayerNode *node) {
     return node->getType() == RNNCellLayer::type ||
            node->getType() == LSTMLayer::type ||
-           node->getType() == LSTMCellLayer::type ||
            node->getType() == ZoneoutLSTMCellLayer::type ||
            node->getType() == GRUCellLayer::type;
   };
index c8a01e8..69655f7 100644 (file)
 
 namespace nntrainer {
 
-static constexpr size_t SINGLE_INOUT_IDX = 0;
-
 enum LSTMCellParams {
   weight_ih,
   weight_hh,
   bias_h,
   bias_ih,
   bias_hh,
-  hidden_state,
-  cell_state,
   ifgo,
   dropout_mask
 };
@@ -39,7 +35,7 @@ LSTMCellLayer::LSTMCellLayer() :
   lstmcell_props(props::Unit(), props::IntegrateBias(),
                  props::HiddenStateActivation() = ActivationType::ACT_TANH,
                  props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
-                 props::DropOutRate(), props::MaxTimestep(), props::Timestep()),
+                 props::DropOutRate()),
   acti_func(ActivationType::ACT_NONE, true),
   recurrent_acti_func(ActivationType::ACT_NONE, true),
   epsilon(1e-3) {
@@ -69,30 +65,51 @@ void LSTMCellLayer::finalize(InitLayerContext &context) {
   const ActivationType recurrent_activation_type =
     std::get<props::RecurrentActivation>(lstmcell_props).get();
   const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props).get();
-  const unsigned int max_timestep =
-    std::get<props::MaxTimestep>(lstmcell_props).get();
 
-  if (context.getNumInputs() != 1)
-    throw std::invalid_argument("LSTMCell layer takes only one input");
-  if (std::get<props::MaxTimestep>(lstmcell_props).empty())
-    throw std::invalid_argument(
-      "Number of unroll steps(max timestep) must be provided to LSTM cell");
-  if (std::get<props::Timestep>(lstmcell_props).empty())
+  if (context.getNumInputs() != 3) {
     throw std::invalid_argument(
-      "Current Timestep must be provided to LSTM cell");
+      "Number of input is not 3. LSTMCell layer should takes 3 inputs");
+  }
 
   // input_dim = [ batch_size, 1, 1, feature_size ]
-  const TensorDim &input_dim = context.getInputDimensions()[0];
-  if (input_dim.channel() != 1 || input_dim.height() != 1)
+  const TensorDim &input_dim = context.getInputDimensions()[INOUT_INDEX::INPUT];
+  if (input_dim.channel() != 1 || input_dim.height() != 1) {
     throw std::invalid_argument(
       "Input must be single time dimension for LSTMCell (shape should be "
       "[batch_size, 1, 1, feature_size])");
+  }
+  // input_hidden_state_dim = [ batch, 1, 1, unit ]
+  const TensorDim &input_hidden_state_dim =
+    context.getInputDimensions()[INOUT_INDEX::INPUT_HIDDEN_STATE];
+  if (input_hidden_state_dim.channel() != 1 ||
+      input_hidden_state_dim.height() != 1) {
+    throw std::invalid_argument("Input hidden state's dimension should be "
+                                "[batch, 1, 1, unit] for LSTMCell");
+  }
+  // input_cell_state_dim = [ batch, 1, 1, unit ]
+  const TensorDim &input_cell_state_dim =
+    context.getInputDimensions()[INOUT_INDEX::INPUT_CELL_STATE];
+  if (input_cell_state_dim.channel() != 1 ||
+      input_cell_state_dim.height() != 1) {
+    throw std::invalid_argument("Input cell state's dimension should be "
+                                "[batch, 1, 1, unit] for LSTMCell");
+  }
   const unsigned int batch_size = input_dim.batch();
   const unsigned int feature_size = input_dim.width();
 
-  // output_dim = [ batch_size, 1, 1, unit ]
-  const TensorDim output_dim(batch_size, 1, 1, unit);
-  context.setOutputDimensions({output_dim});
+  // output_hidden_state_dim = [ batch_size, 1, 1, unit ]
+  const TensorDim output_hidden_state_dim = input_hidden_state_dim;
+  // output_cell_state_dim = [ batch_size, 1, 1, unit ]
+  const TensorDim output_cell_state_dim = input_cell_state_dim;
+
+  std::vector<VarGradSpecV2> out_specs;
+  out_specs.push_back(
+    InitLayerContext::outSpec(output_hidden_state_dim, "output_hidden_state",
+                              TensorLifespan::FORWARD_DERIV_LIFESPAN));
+  out_specs.push_back(
+    InitLayerContext::outSpec(output_cell_state_dim, "output_cell_state",
+                              TensorLifespan::FORWARD_DERIV_LIFESPAN));
+  context.requestOutputs(std::move(out_specs));
 
   // weight_initializer can be set seperately. weight_ih initializer,
   // weight_hh initializer kernel initializer & recurrent_initializer in keras
@@ -134,21 +151,6 @@ void LSTMCellLayer::finalize(InitLayerContext &context) {
     }
   }
 
-  /**
-   * TODO: hidden_state is only used from the previous timestep. Once it is
-   * supported as input, no need to cache the hidden_state itself
-   */
-  /** hidden_state_dim = [ max_timestep * batch_size, 1, 1, unit ] */
-  const TensorDim hidden_state_dim(max_timestep * batch_size, 1, 1, unit);
-  wt_idx[LSTMCellParams::hidden_state] = context.requestTensor(
-    hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true,
-    TensorLifespan::ITERATION_LIFESPAN, false);
-  /** cell_state_dim = [ max_timestep * batch_size, 1, 1, unit ] */
-  const TensorDim cell_state_dim(max_timestep * batch_size, 1, 1, unit);
-  wt_idx[LSTMCellParams::cell_state] = context.requestTensor(
-    cell_state_dim, "cell_state", Tensor::Initializer::NONE, true,
-    TensorLifespan::ITERATION_LIFESPAN, false);
-
   /** ifgo_dim = [ batch_size, 1, 1, NUM_GATE * unit ] */
   const TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit);
   wt_idx[LSTMCellParams::ifgo] =
@@ -187,12 +189,15 @@ void LSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
   const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props).get();
   const bool integrate_bias =
     std::get<props::IntegrateBias>(lstmcell_props).get();
-  const unsigned int max_timestep =
-    std::get<props::MaxTimestep>(lstmcell_props).get();
-  const unsigned int timestep = std::get<props::Timestep>(lstmcell_props).get();
 
-  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
-  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
+  const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
+  const Tensor &prev_hidden_state =
+    context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
+  const Tensor &prev_cell_state =
+    context.getInput(INOUT_INDEX::INPUT_CELL_STATE);
+  Tensor &hidden_state = context.getOutput(INOUT_INDEX::OUTPUT_HIDDEN_STATE);
+  Tensor &cell_state = context.getOutput(INOUT_INDEX::OUTPUT_CELL_STATE);
+
   const unsigned int batch_size = input.getDim().batch();
 
   const Tensor &weight_ih =
@@ -200,41 +205,15 @@ void LSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
   const Tensor &weight_hh =
     context.getWeight(wt_idx[LSTMCellParams::weight_hh]);
   Tensor empty;
-  Tensor &bias_h = !disable_bias && integrate_bias
-                     ? context.getWeight(wt_idx[LSTMCellParams::bias_h])
-                     : empty;
-  Tensor &bias_ih = !disable_bias && !integrate_bias
-                      ? context.getWeight(wt_idx[LSTMCellParams::bias_ih])
-                      : empty;
-  Tensor &bias_hh = !disable_bias && !integrate_bias
-                      ? context.getWeight(wt_idx[LSTMCellParams::bias_hh])
-                      : empty;
-
-  Tensor &hs = context.getTensor(wt_idx[LSTMCellParams::hidden_state]);
-  hs.reshape({max_timestep, 1, batch_size, unit});
-  Tensor prev_hidden_state;
-  if (!timestep) {
-    prev_hidden_state = Tensor(batch_size, unit);
-    prev_hidden_state.setZero();
-  } else {
-    prev_hidden_state = hs.getBatchSlice(timestep - 1, 1);
-  }
-  prev_hidden_state.reshape({batch_size, 1, 1, unit});
-  Tensor hidden_state = hs.getBatchSlice(timestep, 1);
-  hidden_state.reshape({batch_size, 1, 1, unit});
-
-  Tensor &cs = context.getTensor(wt_idx[LSTMCellParams::cell_state]);
-  cs.reshape({max_timestep, 1, batch_size, unit});
-  Tensor prev_cell_state;
-  if (!timestep) {
-    prev_cell_state = Tensor(batch_size, unit);
-    prev_cell_state.setZero();
-  } else {
-    prev_cell_state = cs.getBatchSlice(timestep - 1, 1);
-  }
-  prev_cell_state.reshape({batch_size, 1, 1, unit});
-  Tensor cell_state = cs.getBatchSlice(timestep, 1);
-  cell_state.reshape({batch_size, 1, 1, unit});
+  const Tensor &bias_h = !disable_bias && integrate_bias
+                           ? context.getWeight(wt_idx[LSTMCellParams::bias_h])
+                           : empty;
+  const Tensor &bias_ih = !disable_bias && !integrate_bias
+                            ? context.getWeight(wt_idx[LSTMCellParams::bias_ih])
+                            : empty;
+  const Tensor &bias_hh = !disable_bias && !integrate_bias
+                            ? context.getWeight(wt_idx[LSTMCellParams::bias_hh])
+                            : empty;
 
   Tensor &ifgo = context.getTensor(wt_idx[LSTMCellParams::ifgo]);
 
@@ -249,15 +228,14 @@ void LSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
     dropout_mask.dropout_mask(dropout_rate);
     hidden_state.multiply_i(dropout_mask);
   }
-
-  output.copyData(hidden_state);
 }
 
 void LSTMCellLayer::calcDerivative(RunLayerContext &context) {
   Tensor &d_ifgo = context.getTensorGrad(wt_idx[LSTMCellParams::ifgo]);
   const Tensor &weight_ih =
     context.getWeight(wt_idx[LSTMCellParams::weight_ih]);
-  Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
+  Tensor &outgoing_derivative =
+    context.getOutgoingDerivative(INOUT_INDEX::INPUT);
 
   lstmcell_calcDerivative(d_ifgo, weight_ih, outgoing_derivative);
 }
@@ -270,13 +248,22 @@ void LSTMCellLayer::calcGradient(RunLayerContext &context) {
   const bool integrate_bias =
     std::get<props::IntegrateBias>(lstmcell_props).get();
   const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props);
-  const unsigned int max_timestep =
-    std::get<props::MaxTimestep>(lstmcell_props);
-  const unsigned int timestep = std::get<props::Timestep>(lstmcell_props);
 
-  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
-  const Tensor &incoming_derivative =
-    context.getIncomingDerivative(SINGLE_INOUT_IDX);
+  const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
+  const Tensor &prev_hidden_state =
+    context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
+  Tensor &d_prev_hidden_state =
+    context.getOutgoingDerivative(INOUT_INDEX::INPUT_HIDDEN_STATE);
+  const Tensor &prev_cell_state =
+    context.getInput(INOUT_INDEX::INPUT_CELL_STATE);
+  Tensor &d_prev_cell_state =
+    context.getOutgoingDerivative(INOUT_INDEX::INPUT_CELL_STATE);
+  const Tensor &d_hidden_state =
+    context.getIncomingDerivative(INOUT_INDEX::OUTPUT_HIDDEN_STATE);
+  const Tensor &cell_state = context.getOutput(INOUT_INDEX::OUTPUT_CELL_STATE);
+  const Tensor &d_cell_state =
+    context.getIncomingDerivative(INOUT_INDEX::OUTPUT_CELL_STATE);
+
   unsigned int batch_size = input.getDim().batch();
 
   Tensor &d_weight_ih =
@@ -296,99 +283,48 @@ void LSTMCellLayer::calcGradient(RunLayerContext &context) {
                         ? context.getWeightGrad(wt_idx[LSTMCellParams::bias_hh])
                         : empty;
 
-  Tensor &hs = context.getTensor(wt_idx[LSTMCellParams::hidden_state]);
-  hs.reshape({max_timestep, 1, batch_size, unit});
-  Tensor prev_hidden_state;
-  if (!timestep) {
-    prev_hidden_state = Tensor(batch_size, unit);
-    prev_hidden_state.setZero();
-  } else {
-    prev_hidden_state = hs.getBatchSlice(timestep - 1, 1);
-  }
-  prev_hidden_state.reshape({batch_size, 1, 1, unit});
-
-  Tensor &d_hs = context.getTensorGrad(wt_idx[LSTMCellParams::hidden_state]);
-  d_hs.reshape({max_timestep, 1, batch_size, unit});
-  Tensor d_prev_hidden_state;
-  if (!timestep) {
-    d_prev_hidden_state = Tensor(batch_size, unit);
-    d_prev_hidden_state.setZero();
-  } else {
-    d_prev_hidden_state = d_hs.getBatchSlice(timestep - 1, 1);
-  }
-  d_prev_hidden_state.reshape({batch_size, 1, 1, unit});
-  Tensor d_hidden_state = d_hs.getBatchSlice(timestep, 1);
-  d_hidden_state.reshape({batch_size, 1, 1, unit});
-
-  Tensor &cs = context.getTensor(wt_idx[LSTMCellParams::cell_state]);
-  cs.reshape({max_timestep, 1, batch_size, unit});
-  Tensor prev_cell_state;
-  if (!timestep) {
-    prev_cell_state = Tensor(batch_size, unit);
-    prev_cell_state.setZero();
-  } else {
-    prev_cell_state = cs.getBatchSlice(timestep - 1, 1);
-  }
-  prev_cell_state.reshape({batch_size, 1, 1, unit});
-  Tensor cell_state = cs.getBatchSlice(timestep, 1);
-  cell_state.reshape({batch_size, 1, 1, unit});
-
-  Tensor &d_cs = context.getTensorGrad(wt_idx[LSTMCellParams::cell_state]);
-  d_cs.reshape({max_timestep, 1, batch_size, unit});
-  Tensor d_prev_cell_state;
-  if (!timestep) {
-    d_prev_cell_state = Tensor(batch_size, unit);
-    d_prev_cell_state.setZero();
-  } else {
-    d_prev_cell_state = d_cs.getBatchSlice(timestep - 1, 1);
-  }
-  d_prev_cell_state.reshape({batch_size, 1, 1, unit});
-  Tensor d_cell_state = d_cs.getBatchSlice(timestep, 1);
-  d_cell_state.reshape({batch_size, 1, 1, unit});
-
   const Tensor &ifgo = context.getTensor(wt_idx[LSTMCellParams::ifgo]);
   Tensor &d_ifgo = context.getTensorGrad(wt_idx[LSTMCellParams::ifgo]);
 
-  if (timestep + 1 == max_timestep) {
+  if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::weight_ih])) {
     d_weight_ih.setZero();
+  }
+  if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::weight_hh])) {
     d_weight_hh.setZero();
-    if (!disable_bias) {
-      if (integrate_bias) {
+  }
+  if (!disable_bias) {
+    if (integrate_bias) {
+      if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::bias_h])) {
         d_bias_h.setZero();
-      } else {
+      }
+    } else {
+      if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::bias_ih])) {
         d_bias_ih.setZero();
+      }
+      if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::bias_hh])) {
         d_bias_hh.setZero();
       }
     }
-
-    d_hidden_state.setZero();
-    d_cell_state.setZero();
   }
 
+  Tensor d_hidden_state_masked;
   if (dropout_rate > epsilon) {
     Tensor &dropout_mask =
       context.getTensor(wt_idx[LSTMCellParams::dropout_mask]);
-    d_hidden_state.multiply_i(dropout_mask);
+    d_hidden_state.multiply(dropout_mask, d_hidden_state_masked);
   }
 
-  d_hidden_state.add_i(incoming_derivative);
-
-  lstmcell_calcGradient(unit, batch_size, disable_bias, integrate_bias,
-                        acti_func, recurrent_acti_func, input,
-                        prev_hidden_state, d_prev_hidden_state, prev_cell_state,
-                        d_prev_cell_state, d_hidden_state, cell_state,
-                        d_cell_state, d_weight_ih, weight_hh, d_weight_hh,
-                        d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo);
+  lstmcell_calcGradient(
+    unit, batch_size, disable_bias, integrate_bias, acti_func,
+    recurrent_acti_func, input, prev_hidden_state, d_prev_hidden_state,
+    prev_cell_state, d_prev_cell_state,
+    dropout_rate > epsilon ? d_hidden_state_masked : d_hidden_state, cell_state,
+    d_cell_state, d_weight_ih, weight_hh, d_weight_hh, d_bias_h, d_bias_ih,
+    d_bias_hh, ifgo, d_ifgo);
 }
 
 void LSTMCellLayer::setBatch(RunLayerContext &context, unsigned int batch) {
   const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props);
-  const unsigned int max_timestep =
-    std::get<props::MaxTimestep>(lstmcell_props);
-  context.updateTensor(wt_idx[LSTMCellParams::hidden_state],
-                       max_timestep * batch);
-  context.updateTensor(wt_idx[LSTMCellParams::cell_state],
-                       max_timestep * batch);
   context.updateTensor(wt_idx[LSTMCellParams::ifgo], batch);
   if (dropout_rate > epsilon) {
     context.updateTensor(wt_idx[LSTMCellParams::dropout_mask], batch);
index 8b7d537..9c0225d 100644 (file)
@@ -56,6 +56,7 @@ public:
    * @copydoc Layer::calcGradient(RunLayerContext &context)
    */
   void calcGradient(RunLayerContext &context) override;
+
   /**
    * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
    */
@@ -86,6 +87,13 @@ public:
 
 private:
   static constexpr unsigned int NUM_GATE = 4;
+  enum INOUT_INDEX {
+    INPUT = 0,
+    INPUT_HIDDEN_STATE = 1,
+    INPUT_CELL_STATE = 2,
+    OUTPUT_HIDDEN_STATE = 0,
+    OUTPUT_CELL_STATE = 1
+  };
 
   /**
    * Unit: number of output neurons
@@ -93,15 +101,12 @@ private:
    * HiddenStateActivation: activation type for hidden state. default is tanh
    * RecurrentActivation: activation type for recurrent. default is sigmoid
    * DropOutRate: dropout rate
-   * MaxTimestep: maximum timestep for lstmcell
-   * TimeStep: timestep for which lstm should operate
    *
    * */
   std::tuple<props::Unit, props::IntegrateBias, props::HiddenStateActivation,
-             props::RecurrentActivation, props::DropOutRate, props::MaxTimestep,
-             props::Timestep>
+             props::RecurrentActivation, props::DropOutRate>
     lstmcell_props;
-  std::array<unsigned int, 9> wt_idx; /**< indices of the weights */
+  std::array<unsigned int, 7> wt_idx; /**< indices of the weights */
 
   /**
    * @brief     activation function for h_t : default is tanh
index dc5dcd3..f2263eb 100644 (file)
Binary files a/packaging/unittest_layers_v2.tar.gz and b/packaging/unittest_layers_v2.tar.gz differ
index bac0242..28b6020 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
index fbf513d..3ff2976 100644 (file)
@@ -119,6 +119,13 @@ if __name__ == "__main__":
                          return_state=False)
     record_single(lstm, (3, 4, 7), "lstm_multi_step_seq_act")
 
+    unit, batch_size, unroll_for, feature_size, state_num = [5, 3, 1, 7, 2]
+    lstmcell = K.layers.LSTMCell(units=unit,
+                         activation="tanh",
+                         recurrent_activation="sigmoid",
+                         bias_initializer='glorot_uniform')
+    record_single(lstmcell, [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num)], "lstmcell_single_step", input_type='float')
+
     gru = K.layers.GRU(units=5, activation="tanh", 
                          recurrent_activation="sigmoid",
                          bias_initializer='GlorotUniform',
index 24f807c..c20f3d4 100644 (file)
@@ -83,6 +83,36 @@ class LSTMStacked(torch.nn.Module):
         loss = self.loss(ret, labels[0])
         return ret, loss
 
+class LSTMCellStacked(torch.nn.Module):
+    def __init__(self, unroll_for=2, num_lstmcell=1):
+        super().__init__()
+        self.input_size = self.hidden_size = 2
+        self.lstmcells = torch.nn.ModuleList(
+            [
+                torch.nn.LSTMCell(self.input_size, self.hidden_size, bias=True)
+                for _ in range(num_lstmcell)
+            ]
+        )
+        self.unroll_for = unroll_for
+        self.num_lstmcell = num_lstmcell
+        self.loss = torch.nn.MSELoss()
+
+    def forward(self, inputs, labels):
+        out = inputs[0]
+        states = inputs[1:]
+        hs = [states[2 * i] for i in range(self.num_lstmcell)]
+        cs = [states[2 * i + 1] for i in range(self.num_lstmcell)]
+        ret = []
+        for _ in range(self.unroll_for):
+            for i, (lstm, h, c) in enumerate(zip(self.lstmcells, hs, cs)):
+                hs[i], cs[i] = lstm(out, (h, c))
+                out = hs[i]
+            ret.append(out)
+
+        ret = torch.stack(ret, dim=1)
+        loss = self.loss(ret, labels[0])
+        return ret, loss
+
 class ZoneoutLSTMStacked(torch.nn.Module):
     def __init__(self, batch_size=3, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=1, cell_state_zoneout_rate=1):
         super().__init__()
@@ -199,6 +229,24 @@ if __name__ == "__main__":
         name="lstm_stacked",
     )
 
+    unroll_for, num_lstmcell, state_num, batch_size, unit, feature_size, iteration = [2, 1, 2, 3, 2, 2, 2]
+    record_v2(
+        LSTMCellStacked(unroll_for=unroll_for, num_lstmcell=num_lstmcell),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstmcell)],
+        label_dims=[(batch_size, unroll_for, unit)],
+        name="lstmcell_single",
+    )
+
+    unroll_for, num_lstmcell, state_num, batch_size, unit, feature_size, iteration = [2, 2, 2, 3, 2, 2, 2]
+    record_v2(
+        LSTMCellStacked(unroll_for=unroll_for, num_lstmcell=num_lstmcell),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstmcell)],
+        label_dims=[(batch_size, unroll_for, unit)],
+        name="lstmcell_stacked",
+    )
+
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 0.0, 0.0]
     record_v2(
         ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
index f526096..c859f3d 100644 (file)
@@ -214,6 +214,23 @@ class MultiOutLayer(IdentityTransLayer):
         return [layer(tf_output) for layer in self.stub_layers]
 
 ##
+# @brief Translayer for lstmcell layer
+class LSTMCellTransLayer(IdentityTransLayer):
+    def build(self, input_shape):
+        if not self.built:
+            self.tf_layer.build(input_shape[0])
+            super().build(input_shape[0])
+
+    ##
+    # @brief call function
+    # @param inputs input with nntrainer layout
+    def call(self, inputs):
+        input = inputs[0]
+        states = inputs[1:]
+        _, states = self.tf_layer.call(input, states)
+        return states
+
+##
 # @brief Translayer for gru layer
 class GRUTransLayer(IdentityTransLayer):
     def to_nntr_weights(self, tensorOrList):
@@ -239,6 +256,9 @@ def attach_trans_layer(layer):
     if isinstance(layer, CHANNEL_LAST_LAYERS):
         return ChannelLastTransLayer(layer)
 
+    if isinstance(layer, K.layers.LSTMCell):
+        return LSTMCellTransLayer(layer)
+
     if isinstance(layer, K.layers.GRU):
         return GRUTransLayer(layer)
 
index 730d5e2..2798124 100644 (file)
 
 auto semantic_lstmcell = LayerSemanticsParamType(
   nntrainer::createLayer<nntrainer::LSTMCellLayer>,
-  nntrainer::LSTMCellLayer::type, {"unit=1", "timestep=0", "max_timestep=1"}, 0,
-  false, 1);
+  nntrainer::LSTMCellLayer::type, {"unit=1"}, 0, false, 3);
 
 INSTANTIATE_TEST_CASE_P(LSTMCell, LayerSemantics,
                         ::testing::Values(semantic_lstmcell));
 
 auto lstmcell_single_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LSTMCellLayer>,
-  {"unit=5", "timestep=0", "max_timestep=1", "integrate_bias=true"}, "3:1:1:7",
-  "lstm_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  {"unit=5", "integrate_bias=true"}, "3:1:1:7,3:1:1:5,3:1:1:5",
+  "lstmcell_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
 
 INSTANTIATE_TEST_CASE_P(LSTMCell, LayerGoldenTest,
                         ::testing::Values(lstmcell_single_step));
index bce03e1..d7ac3b2 100644 (file)
@@ -204,26 +204,36 @@ static std::unique_ptr<NeuralNetwork> makeSingleLSTMCell() {
 
   auto outer_graph = makeGraph({
     {"input", {"name=input", "input_shape=1:1:2"}},
+    {"input", {"name=input_hidden_state", "input_shape=1:1:2"}},
+    {"input", {"name=input_cell_state", "input_shape=1:1:2"}},
     /// here lstm_cells is being inserted
-    {"mse", {"name=loss", "input_layers=lstm_scope/a1"}},
+    {"mse", {"name=loss", "input_layers=lstmcell_scope/a1(0)"}},
   });
   for (auto &node : outer_graph) {
     nn->addLayer(node);
   }
 
-  auto lstm = makeGraph({
-    {"lstmcell", {"name=a1", "unit=2", "integrate_bias=false"}},
+  auto lstmcell = makeGraph({
+    {"input", {"name=dummy_0", "input_shape=1"}},
+    {"input", {"name=dummy_1", "input_shape=1"}},
+    {"input", {"name=dummy_2", "input_shape=1"}},
+    {"lstmcell",
+     {"name=a1", "unit=2", "input_layers=dummy_0, dummy_1, dummy_2"}},
   });
 
-  nn->addWithReferenceLayers(lstm, "lstm_scope", {"input"}, {"a1"}, {"a1"},
-                             ml::train::ReferenceLayersType::RECURRENT,
-                             {
-                               "unroll_for=2",
-                               "as_sequence=a1",
-                               "recurrent_input=a1",
-                               "recurrent_output=a1",
-                             });
+  nn->addWithReferenceLayers(
+    lstmcell, "lstmcell_scope",
+    {"input", "input_hidden_state", "input_cell_state"},
+    {"a1(0)", "a1(1)", "a1(2)"}, {"a1"},
+    ml::train::ReferenceLayersType::RECURRENT,
+    {
+      "unroll_for=2",
+      "as_sequence=a1",
+      "recurrent_input=a1(0), a1(1), a1(2)",
+      "recurrent_output=a1(0), a1(0), a1(1)",
+    });
 
+  nn->setProperty({"input_layers=input, input_hidden_state, input_cell_state"});
   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
   return nn;
 }
@@ -234,28 +244,49 @@ static std::unique_ptr<NeuralNetwork> makeStackedLSTMCell() {
 
   auto outer_graph = makeGraph({
     {"input", {"name=input", "input_shape=1:1:2"}},
+    {"input", {"name=a1_input_hidden_state", "input_shape=1:1:2"}},
+    {"input", {"name=a1_input_cell_state", "input_shape=1:1:2"}},
+    {"input", {"name=a2_input_hidden_state", "input_shape=1:1:2"}},
+    {"input", {"name=a2_input_cell_state", "input_shape=1:1:2"}},
     /// here lstm_cells is being inserted
-    {"mse", {"name=loss", "input_layers=lstm_scope/a2"}},
+    {"mse", {"name=loss", "input_layers=lstmcell_scope/a2(0)"}},
   });
   for (auto &node : outer_graph) {
     nn->addLayer(node);
   }
 
-  auto lstm = makeGraph({
-    {"lstmcell", {"name=a1", "unit=2", "integrate_bias=false"}},
+  auto lstmcell = makeGraph({
+    {"input", {"name=dummy_0", "input_shape=1"}},
+    {"input", {"name=dummy_1", "input_shape=1"}},
+    {"input", {"name=dummy_2", "input_shape=1"}},
+    {"input", {"name=dummy_3", "input_shape=1"}},
+    {"input", {"name=dummy_4", "input_shape=1"}},
     {"lstmcell",
-     {"name=a2", "unit=2", "integrate_bias=false", "input_layers=a1"}},
+     {"name=a1", "unit=2", "input_layers=dummy_0, dummy_1, dummy_2"}},
+    {"lstmcell", {"name=a2", "unit=2", "input_layers=a1(0), dummy_3, dummy_4"}},
   });
 
-  nn->addWithReferenceLayers(lstm, "lstm_scope", {"input"}, {"a1"}, {"a2"},
-                             ml::train::ReferenceLayersType::RECURRENT,
-                             {
-                               "unroll_for=2",
-                               "as_sequence=a2",
-                               "recurrent_input=a1",
-                               "recurrent_output=a2",
-                             });
+  nn->addWithReferenceLayers(
+    lstmcell, "lstmcell_scope",
+    {
+      "input",
+      "a1_input_hidden_state",
+      "a1_input_cell_state",
+      "a2_input_hidden_state",
+      "a2_input_cell_state",
+    },
+    {"a1(0)", "a1(1)", "a1(2)", "a2(1)", "a2(2)"}, {"a2"},
+    ml::train::ReferenceLayersType::RECURRENT,
+    {
+      "unroll_for=2",
+      "as_sequence=a2",
+      "recurrent_input=a1(0), a1(1), a1(2), a2(1), a2(2)",
+      "recurrent_output=a2(0), a1(0), a1(1), a2(0), a2(1)",
+    });
 
+  nn->setProperty(
+    {"input_layers=input, a1_input_hidden_state, a1_input_cell_state, "
+     "a2_input_hidden_state, a2_input_cell_state"});
   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
   return nn;
 }
@@ -287,7 +318,8 @@ static std::unique_ptr<NeuralNetwork> makeSingleZoneoutLSTMCell() {
 
   nn->addWithReferenceLayers(
     zoneout_lstm, "zoneout_lstm_scope",
-    {"input", "input_hidden_state", "input_cell_state"}, {"a1"}, {"a1"},
+    {"input", "input_hidden_state", "input_cell_state"},
+    {"a1(0)", "a1(1)", "a1(2)"}, {"a1"},
     ml::train::ReferenceLayersType::RECURRENT,
     {
       "unroll_for=2",
@@ -343,7 +375,8 @@ static std::unique_ptr<NeuralNetwork> makeStackedZoneoutLSTMCell() {
       "a2_input_hidden_state",
       "a2_input_cell_state",
     },
-    {"a1", "a2"}, {"a2"}, ml::train::ReferenceLayersType::RECURRENT,
+    {"a1(0)", "a1(1)", "a1(2)", "a2(1)", "a2(2)"}, {"a2"},
+    ml::train::ReferenceLayersType::RECURRENT,
     {
       "unroll_for=2",
       "as_sequence=a2",
@@ -497,56 +530,56 @@ INSTANTIATE_TEST_CASE_P(
     mkModelTc_V2(makeFC, "fc_unroll_stacked", ModelTestOption::COMPARE_V2),
     mkModelTc_V2(makeFCClipped, "fc_unroll_stacked_clipped",
                  ModelTestOption::COMPARE_V2),
-    mkModelTc_V2(makeSingleLSTM, "lstm_single", ModelTestOption::COMPARE_V2),
-    mkModelTc_V2(makeSingleLSTMCell, "lstm_single__1",
-                 ModelTestOption::COMPARE_V2),
-    mkModelTc_V2(makeStackedLSTM, "lstm_stacked", ModelTestOption::COMPARE_V2),
-    mkModelTc_V2(makeStackedLSTMCell, "lstm_stacked__1",
-                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeSingleLSTM, "lstm_single", ModelTestOption::ALL_V2),
+    mkModelTc_V2(makeStackedLSTM, "lstm_stacked", ModelTestOption::ALL_V2),
+    mkModelTc_V2(makeSingleLSTMCell, "lstmcell_single",
+                 ModelTestOption::ALL_V2),
+    mkModelTc_V2(makeStackedLSTMCell, "lstmcell_stacked",
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_000_000",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_000_000",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_050_000",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_050_000",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_100_000",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_100_000",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_000_050",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_000_050",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_050_050",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_050_050",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_100_050",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_100_050",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_000_100",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_000_100",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_050_100",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_050_100",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_100_100",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_100_100",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeSingleRNNCell, "rnncell_single__1",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedRNNCell, "rnncell_stacked__1",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeSingleGRUCell, "grucell_single__1",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedGRUCell, "grucell_stacked__1",
-                 ModelTestOption::COMPARE_V2),
+                 ModelTestOption::ALL_V2),
   }),
   [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info) {
     return std::get<1>(info.param);