[lstm] remove timestep property
authorhyeonseok lee <hs89.lee@samsung.com>
Wed, 12 Jan 2022 11:29:41 +0000 (20:29 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 7 Feb 2022 08:36:39 +0000 (17:36 +0900)
 - Remove timestep property from lstm layer.
   This will disable unrolling the lstm layer.
 - Adjust recurrent unittest to simple lstm unittest.

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/compiler/recurrent_realizer.cpp
nntrainer/layers/lstm.cpp
nntrainer/layers/lstm.h
packaging/unittest_models_v2.tar.gz
test/input_gen/genModelsRecurrent_v2.py
test/input_gen/transLayer_v2.py
test/unittest/compiler/unittest_realizer.cpp
test/unittest/models/unittest_models_recurrent.cpp

index 6e56fd6..71f0f0f 100644 (file)
@@ -20,7 +20,6 @@
 #include <connection.h>
 #include <input_layer.h>
 #include <layer_node.h>
-#include <lstm.h>
 #include <nntrainer_error.h>
 #include <node_exporter.h>
 #include <recurrent_realizer.h>
@@ -186,7 +185,6 @@ static void propagateTimestep(LayerNode *node, unsigned int time_step,
   /** @todo add an interface to check if a layer supports a property */
   auto is_recurrent_type = [](LayerNode *node) {
     return node->getType() == RNNCellLayer::type ||
-           node->getType() == LSTMLayer::type ||
            node->getType() == ZoneoutLSTMCellLayer::type;
   };
 
index 71a0bb7..6afca49 100644 (file)
@@ -40,7 +40,7 @@ LSTMLayer::LSTMLayer() :
              props::HiddenStateActivation() = ActivationType::ACT_TANH,
              props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
              props::ReturnSequences(), props::DropOutRate(),
-             props::MaxTimestep(), props::Timestep()),
+             props::MaxTimestep()),
   acti_func(ActivationType::ACT_NONE, true),
   recurrent_acti_func(ActivationType::ACT_NONE, true),
   epsilon(1e-3) {
@@ -141,18 +141,18 @@ void LSTMLayer::finalize(InitLayerContext &context) {
   const TensorDim hidden_state_dim(batch_size, 1, max_timestep, unit);
   wt_idx[LSTMParams::hidden_state] = context.requestTensor(
     hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true,
-    TensorLifespan::ITERATION_LIFESPAN, false);
+    TensorLifespan::ITERATION_LIFESPAN);
   // cell_state_dim : [ batch_size, 1, max_timestep, unit ]
   const TensorDim cell_state_dim(batch_size, 1, max_timestep, unit);
   wt_idx[LSTMParams::cell_state] = context.requestTensor(
     cell_state_dim, "cell_state", Tensor::Initializer::NONE, true,
-    TensorLifespan::ITERATION_LIFESPAN, false);
+    TensorLifespan::ITERATION_LIFESPAN);
 
   // ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ]
   const TensorDim ifgo_dim(batch_size, 1, max_timestep, NUM_GATE * unit);
   wt_idx[LSTMParams::ifgo] =
     context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true,
-                          TensorLifespan::ITERATION_LIFESPAN, false);
+                          TensorLifespan::ITERATION_LIFESPAN);
 
   if (dropout_rate > epsilon) {
     // dropout_mask_dim = [ batch, 1, time_iteration, unit ]
@@ -189,23 +189,11 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
   const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();
   const unsigned int max_timestep =
     std::get<props::MaxTimestep>(lstm_props).get();
-  const props::Timestep timestep = std::get<props::Timestep>(lstm_props);
-
-  unsigned int start_timestep = 0;
-  unsigned int end_timestep = max_timestep;
-  if (!timestep.empty()) {
-    const unsigned int current_timestep = timestep.get();
-    if (current_timestep >= end_timestep) {
-      throw std::runtime_error("Timestep to run exceeds input dimensions");
-    }
-
-    start_timestep = current_timestep;
-    end_timestep = current_timestep + 1;
-  }
 
   const Tensor &inputs = context.getInput(SINGLE_INOUT_IDX);
-  const unsigned int batch_size = inputs.getDim().batch();
-  const unsigned int feature_size = inputs.getDim().width();
+  const TensorDim input_dim = inputs.getDim();
+  const unsigned int batch_size = input_dim.batch();
+  const unsigned int feature_size = input_dim.width();
   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
 
   const Tensor &weight_ih = context.getWeight(wt_idx[LSTMParams::weight_ih]);
@@ -225,17 +213,8 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
   Tensor &cs = context.getTensor(wt_idx[LSTMParams::cell_state]);
   Tensor &ifgos = context.getTensor(wt_idx[LSTMParams::ifgo]);
 
-  if (!start_timestep) {
-    hs.setZero();
-    cs.setZero();
-  }
-
-  /**
-   * @note when the recurrent realization happens, different instances of lstm
-   * will share the weights, hidden state, cell and ifgo memory. However, they
-   * do not share the input, output and derivatives memory. The input/output
-   * will be contain a single timestep data only.
-   */
+  hs.setZero();
+  cs.setZero();
 
   for (unsigned int batch = 0; batch < batch_size; ++batch) {
     const Tensor input_batch = inputs.getBatchSlice(batch, 1);
@@ -243,7 +222,7 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
     Tensor cs_batch = cs.getBatchSlice(batch, 1);
     Tensor ifgo_batch = ifgos.getBatchSlice(batch, 1);
 
-    for (unsigned int t = start_timestep; t < end_timestep; ++t) {
+    for (unsigned int t = 0; t < max_timestep; ++t) {
       Tensor input;
       if (input_batch.height() != 1)
         input =
@@ -286,12 +265,12 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
     }
   }
 
-  if (start_timestep == 0 && end_timestep == max_timestep && return_sequences) {
+  if (return_sequences) {
     std::copy(hs.getData(), hs.getData() + hs.size(), output.getData());
   } else {
     for (unsigned int batch = 0; batch < batch_size; ++batch) {
       float *hidden_state_data =
-        hs.getAddress(batch * max_timestep * unit + (end_timestep - 1) * unit);
+        hs.getAddress(batch * max_timestep * unit + (max_timestep - 1) * unit);
       float *output_data = output.getAddress(batch * unit);
       std::copy(hidden_state_data, hidden_state_data + unit, output_data);
     }
@@ -299,63 +278,11 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
 }
 
 void LSTMLayer::calcDerivative(RunLayerContext &context) {
-  const unsigned int unit = std::get<props::Unit>(lstm_props).get();
-  const unsigned int max_timestep =
-    std::get<props::MaxTimestep>(lstm_props).get();
-  const props::Timestep timestep = std::get<props::Timestep>(lstm_props);
-
-  unsigned int start_timestep = 0;
-  unsigned int end_timestep = max_timestep;
-  if (!timestep.empty()) {
-    const unsigned int cur_timestep = timestep.get();
-    // Todo: replace end_timestep with input's time iteration
-    if (cur_timestep >= end_timestep) {
-      throw std::runtime_error("Timestep to run exceeds input dimensions");
-    }
-
-    start_timestep = cur_timestep;
-    end_timestep = cur_timestep + 1;
-  }
-  const unsigned int timestep_diff = end_timestep - start_timestep;
-
-  const TensorDim input_dim = context.getInput(SINGLE_INOUT_IDX).getDim();
-  const unsigned int batch_size = input_dim.batch();
-  const unsigned int feature_size = input_dim.width();
-
-  const Tensor &d_ifgos = context.getTensorGrad(wt_idx[LSTMParams::ifgo]);
-  const Tensor &weight_ih = context.getWeight(wt_idx[LSTMParams::weight_ih]);
   Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
+  const Tensor &weight_ih = context.getWeight(wt_idx[LSTMParams::weight_ih]);
+  const Tensor &d_ifgos = context.getTensorGrad(wt_idx[LSTMParams::ifgo]);
 
-  if (start_timestep == 0 && end_timestep == max_timestep) {
-    /**
-     * this if is only for optimization purpose. The else calculates for
-     * this scenario as well.
-     */
-    lstmcell_calcDerivative(d_ifgos, weight_ih, outgoing_derivative);
-  } else {
-    for (unsigned int b = 0; b < batch_size; ++b) {
-      const Tensor d_ifgo_batch = d_ifgos.getBatchSlice(b, 1);
-      Tensor outgoing_derivative_batch =
-        outgoing_derivative.getBatchSlice(b, 1);
-      Tensor d_ifgo, outgoing_derivative_;
-
-      if (d_ifgo_batch.height() != 1) {
-        d_ifgo = d_ifgo_batch.getSharedDataTensor(
-          {timestep_diff, NUM_GATE * unit}, start_timestep * NUM_GATE * unit);
-      } else {
-        d_ifgo = d_ifgo_batch;
-      }
-
-      if (outgoing_derivative_batch.height() != 1) {
-        outgoing_derivative_ = outgoing_derivative_batch.getSharedDataTensor(
-          {timestep_diff, feature_size}, start_timestep * feature_size);
-      } else {
-        outgoing_derivative_ = outgoing_derivative_batch;
-      }
-
-      lstmcell_calcDerivative(d_ifgo, weight_ih, outgoing_derivative_);
-    }
-  }
+  lstmcell_calcDerivative(d_ifgos, weight_ih, outgoing_derivative);
 }
 
 void LSTMLayer::calcGradient(RunLayerContext &context) {
@@ -369,18 +296,9 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
   const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();
   const unsigned int max_timestep =
     std::get<props::MaxTimestep>(lstm_props).get();
-  const props::Timestep timestep = std::get<props::Timestep>(lstm_props);
 
   unsigned int start_timestep = max_timestep - 1;
   int end_timestep = -1;
-  if (!timestep.empty()) {
-    const unsigned int cur_timestep = timestep.get();
-    NNTR_THROW_IF(cur_timestep > start_timestep, std::runtime_error)
-      << "Timestep to run exceeds input dimension current timestep"
-      << cur_timestep << "start_timestep" << start_timestep;
-    start_timestep = cur_timestep;
-    end_timestep = cur_timestep - 1;
-  }
 
   const Tensor &inputs = context.getInput(SINGLE_INOUT_IDX);
   const Tensor &incoming_derivative =
@@ -411,24 +329,21 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
   Tensor &ifgos = context.getTensor(wt_idx[LSTMParams::ifgo]);
   Tensor &d_ifgos = context.getTensorGrad(wt_idx[LSTMParams::ifgo]);
 
-  if (start_timestep + 1 == max_timestep) {
-    d_weight_ih.setZero();
-    d_weight_hh.setZero();
-    if (!disable_bias) {
-      if (integrate_bias) {
-        d_bias_h.setZero();
-      } else {
-        d_bias_ih.setZero();
-        d_bias_hh.setZero();
-      }
+  d_weight_ih.setZero();
+  d_weight_hh.setZero();
+  if (!disable_bias) {
+    if (integrate_bias) {
+      d_bias_h.setZero();
+    } else {
+      d_bias_ih.setZero();
+      d_bias_hh.setZero();
     }
-
-    d_cs.setZero();
-    d_hs.setZero();
   }
 
-  if (start_timestep == max_timestep - 1 && end_timestep == -1 &&
-      return_sequences) {
+  d_cs.setZero();
+  d_hs.setZero();
+
+  if (return_sequences) {
     std::copy(incoming_derivative.getData(),
               incoming_derivative.getData() + incoming_derivative.size(),
               d_hs.getData());
index 5667f7f..804ba0e 100644 (file)
@@ -107,13 +107,12 @@ private:
    * RecurrentActivation: activation type for recurrent. default is sigmoid
    * ReturnSequence: option for return sequence
    * DropOutRate: dropout rate
-   * MaxTimestep: maximum timestep for lstmcell
-   * TimeStep: timestep for which lstm should operate
+   * MaxTimestep: maximum timestep for lstm
    *
    * */
   std::tuple<props::Unit, props::IntegrateBias, props::HiddenStateActivation,
              props::RecurrentActivation, props::ReturnSequences,
-             props::DropOutRate, props::MaxTimestep, props::Timestep>
+             props::DropOutRate, props::MaxTimestep>
     lstm_props;
   std::array<unsigned int, 9> wt_idx; /**< indices of the weights */
 
index f752d8b..4480c3f 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
index 47e1fd8..6197a19 100644 (file)
@@ -56,32 +56,31 @@ class RNNCellStacked(torch.nn.Module):
         return ret, loss
 
 class LSTMStacked(torch.nn.Module):
-    def __init__(self, unroll_for=2, num_lstm=1):
+    def __init__(self, num_lstm=1):
         super().__init__()
         self.input_size = self.hidden_size = 2
+        self.num_lstm = num_lstm
         self.lstms = torch.nn.ModuleList(
             [
-                torch.nn.LSTMCell(self.input_size, self.hidden_size, bias=True)
+                torch.nn.LSTM(self.input_size, self.hidden_size, batch_first=True)
+                # torch.nn.LSTM(self.input_size, self.hidden_size, num_layers=num_lstm, batch_first=True)
                 for _ in range(num_lstm)
             ]
         )
-        self.unroll_for = unroll_for
         self.loss = torch.nn.MSELoss()
 
     def forward(self, inputs, labels):
-        hs = [torch.zeros_like(inputs[0]) for _ in self.lstms]
-        cs = [torch.zeros_like(inputs[0]) for _ in self.lstms]
         out = inputs[0]
-        ret = []
-        for _ in range(self.unroll_for):
-            for i, (lstm, h, c) in enumerate(zip(self.lstms, hs, cs)):
-                hs[i], cs[i] = lstm(out, (h, c))
-                out = hs[i]
-            ret.append(out)
+        states = inputs[1:]
+        # hs = [states[2 * i] for i in range(self.num_lstm)]
+        hs = [torch.zeros((1, 3, 2)) for _ in range(self.num_lstm)]
+        # cs = [states[2 * i + 1] for i in range(self.num_lstm)]
+        cs = [torch.zeros((1, 3, 2)) for _ in range(self.num_lstm)]
+        for i, (lstm, h, c) in enumerate(zip(self.lstms, hs, cs)):
+            out, (hs[i], cs[i]) = lstm(out, (h, c))
 
-        ret = torch.stack(ret, dim=1)
-        loss = self.loss(ret, labels[0])
-        return ret, loss
+        loss = self.loss(out, labels[0])
+        return out, loss
 
 class LSTMCellStacked(torch.nn.Module):
     def __init__(self, unroll_for=2, num_lstmcell=1):
@@ -89,7 +88,7 @@ class LSTMCellStacked(torch.nn.Module):
         self.input_size = self.hidden_size = 2
         self.lstmcells = torch.nn.ModuleList(
             [
-                torch.nn.LSTMCell(self.input_size, self.hidden_size, bias=True)
+                torch.nn.LSTMCell(self.input_size, self.hidden_size)
                 for _ in range(num_lstmcell)
             ]
         )
@@ -213,19 +212,23 @@ if __name__ == "__main__":
         name="rnncell_stacked",
     )
 
+    unroll_for, num_lstm, batch_size, unit, feature_size, iteration = [2, 1, 3, 2, 2, 2]
     record_v2(
-        LSTMStacked(unroll_for=2, num_lstm=1),
-        iteration=2,
-        input_dims=[(3, 2)],
-        label_dims=[(3, 2, 2)],
+        LSTMStacked(num_lstm=num_lstm),
+        iteration=iteration,
+        input_dims=[(batch_size, unroll_for, feature_size)],
+        # input_dims=[(batch_size, unroll_for, feature_size)] + [(1, batch_size, unit) for _ in range(2 * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="lstm_single",
     )
 
+    unroll_for, num_lstm, batch_size, unit, feature_size, iteration = [2, 2, 3, 2, 2, 2]
     record_v2(
-        LSTMStacked(unroll_for=2, num_lstm=2),
-        iteration=2,
-        input_dims=[(3, 2)],
-        label_dims=[(3, 2, 2)],
+        LSTMStacked(num_lstm=num_lstm),
+        iteration=iteration,
+        input_dims=[(batch_size, unroll_for, feature_size)],
+        # input_dims=[(batch_size, unroll_for, feature_size)] + [(1, batch_size, unit) for _ in range(2 * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="lstm_stacked",
     )
 
index 4f15325..9373d67 100644 (file)
@@ -70,7 +70,7 @@ def zoneout_translate(model):
     new_params = [transpose_(params[0]), transpose_(params[1]), params[2], params[3], hidden_state, cell_state]
     yield from new_params
 
-@register_for_((torch.nn.RNNCell, torch.nn.LSTMCell))
+@register_for_((torch.nn.RNNCell, torch.nn.LSTM, torch.nn.LSTMCell))
 def rnn_lstm_translate(model):
     params = [(name, tensor.detach()) for name, tensor in model.named_parameters()]
     # [hidden, input] -> [input, hidden]
index 967883d..43d74cc 100644 (file)
@@ -137,34 +137,31 @@ TEST(RecurrentRealizer, recurrent_input_is_sequence_p) {
 TEST(RecurrentRealizer, recurrent_return_sequence_single_p) {
   using C = Connection;
   RecurrentRealizer r({"unroll_for=3", "as_sequence=fc_out",
-                       "recurrent_input=lstm", "recurrent_output=fc_out"},
+                       "recurrent_input=lstmcell", "recurrent_output=fc_out"},
                       {C("source")}, {C("fc_out")});
 
   std::vector<LayerRepresentation> before = {
-    {"lstm", {"name=lstm", "input_layers=source"}},
-    {"fully_connected", {"name=fc_out", "input_layers=lstm"}}};
+    {"lstmcell", {"name=lstmcell", "input_layers=source"}},
+    {"fully_connected", {"name=fc_out", "input_layers=lstmcell"}}};
 
   std::vector<LayerRepresentation> expected = {
     /// t - 0
-    {"lstm",
-     {"name=lstm/0", "input_layers=source", "max_timestep=3", "timestep=0",
-      "shared_from=lstm/0"}},
+    {"lstmcell",
+     {"name=lstmcell/0", "input_layers=source", "shared_from=lstmcell/0"}},
     {"fully_connected",
-     {"name=fc_out/0", "input_layers=lstm/0", "shared_from=fc_out/0"}},
+     {"name=fc_out/0", "input_layers=lstmcell/0", "shared_from=fc_out/0"}},
 
     /// t - 1
-    {"lstm",
-     {"name=lstm/1", "input_layers=fc_out/0", "shared_from=lstm/0",
-      "max_timestep=3", "timestep=1"}},
+    {"lstmcell",
+     {"name=lstmcell/1", "input_layers=fc_out/0", "shared_from=lstmcell/0"}},
     {"fully_connected",
-     {"name=fc_out/1", "input_layers=lstm/1", "shared_from=fc_out/0"}},
+     {"name=fc_out/1", "input_layers=lstmcell/1", "shared_from=fc_out/0"}},
 
     /// t - 2
-    {"lstm",
-     {"name=lstm/2", "input_layers=fc_out/1", "shared_from=lstm/0",
-      "max_timestep=3", "timestep=2"}},
+    {"lstmcell",
+     {"name=lstmcell/2", "input_layers=fc_out/1", "shared_from=lstmcell/0"}},
     {"fully_connected",
-     {"name=fc_out/2", "input_layers=lstm/2", "shared_from=fc_out/0"}},
+     {"name=fc_out/2", "input_layers=lstmcell/2", "shared_from=fc_out/0"}},
 
     /// mapping
     {"concat",
@@ -181,53 +178,50 @@ TEST(RecurrentRealizer, recurrent_multi_inout_return_seq_p) {
     {
       "unroll_for=3",
       "as_sequence=fc_out",
-      "recurrent_input=lstm,add(2)",
+      "recurrent_input=lstmcell,add(2)",
       "recurrent_output=fc_out,split(1)",
     },
     {C("source"), C("source2"), C("source3")}, {C("fc_out")});
 
   /// @note for below graph,
-  /// 1. fc_out feds back to lstm
+  /// 1. fc_out feds back to lstmcell
   /// 2. ouput_dummy feds back to source2_dummy
   /// ========================================================
-  /// lstm        -------- addition - split ---- fc_out (to_lstm)
+  /// lstmcell        -------- addition - split ---- fc_out (to_lstmcell)
   /// source2_dummy   --/                  \----- (to addition 3)
   std::vector<LayerRepresentation> before = {
-    {"lstm", {"name=lstm", "input_layers=source"}},
-    {"addition", {"name=add", "input_layers=lstm,source2,source3"}},
+    {"lstmcell", {"name=lstmcell", "input_layers=source"}},
+    {"addition", {"name=add", "input_layers=lstmcell,source2,source3"}},
     {"split", {"name=split", "input_layers=add"}},
     {"fully_connected", {"name=fc_out", "input_layers=split(0)"}},
   };
 
   std::vector<LayerRepresentation> expected = {
     /// timestep 0
-    {"lstm",
-     {"name=lstm/0", "input_layers=source", "max_timestep=3", "timestep=0",
-      "shared_from=lstm/0"}},
+    {"lstmcell",
+     {"name=lstmcell/0", "input_layers=source", "shared_from=lstmcell/0"}},
     {"addition",
-     {"name=add/0", "input_layers=lstm/0,source2,source3",
+     {"name=add/0", "input_layers=lstmcell/0,source2,source3",
       "shared_from=add/0"}},
     {"split", {"name=split/0", "input_layers=add/0", "shared_from=split/0"}},
     {"fully_connected",
      {"name=fc_out/0", "input_layers=split/0(0)", "shared_from=fc_out/0"}},
 
     /// timestep 1
-    {"lstm",
-     {"name=lstm/1", "input_layers=fc_out/0", "shared_from=lstm/0",
-      "max_timestep=3", "timestep=1"}},
+    {"lstmcell",
+     {"name=lstmcell/1", "input_layers=fc_out/0", "shared_from=lstmcell/0"}},
     {"addition",
-     {"name=add/1", "input_layers=lstm/1,source2,split/0(1)",
+     {"name=add/1", "input_layers=lstmcell/1,source2,split/0(1)",
       "shared_from=add/0"}},
     {"split", {"name=split/1", "input_layers=add/1", "shared_from=split/0"}},
     {"fully_connected",
      {"name=fc_out/1", "input_layers=split/1(0)", "shared_from=fc_out/0"}},
 
     /// timestep 2
-    {"lstm",
-     {"name=lstm/2", "input_layers=fc_out/1", "shared_from=lstm/0",
-      "max_timestep=3", "timestep=2"}},
+    {"lstmcell",
+     {"name=lstmcell/2", "input_layers=fc_out/1", "shared_from=lstmcell/0"}},
     {"addition",
-     {"name=add/2", "input_layers=lstm/2,source2,split/1(1)",
+     {"name=add/2", "input_layers=lstmcell/2,source2,split/1(1)",
       "shared_from=add/0"}},
     {"split", {"name=split/2", "input_layers=add/2", "shared_from=split/0"}},
     {"fully_connected",
@@ -247,53 +241,50 @@ TEST(RecurrentRealizer, recurrent_multi_inout_using_connection_p) {
   RecurrentRealizer r(
     {
       "unroll_for=3",
-      "recurrent_input=lstm,add(2)",
+      "recurrent_input=lstmcell,add(2)",
       "recurrent_output=fc_out,split(1)",
     },
     {C("source"), C("source2"), C("source3")}, {C("fc_out")});
 
   /// @note for below graph,
-  /// 1. fc_out feds back to lstm
+  /// 1. fc_out feds back to lstmcell
   /// 2. ouput_dummy feds back to source2_dummy
   /// ========================================================
-  /// lstm        -------- addition - split ---- fc_out (to_lstm)
+  /// lstmcell        -------- addition - split ---- fc_out (to_lstmcell)
   /// source2_dummy   --/                  \----- (to addition 3)
   std::vector<LayerRepresentation> before = {
-    {"lstm", {"name=lstm", "input_layers=source"}},
-    {"addition", {"name=add", "input_layers=lstm,source2,source3"}},
+    {"lstmcell", {"name=lstmcell", "input_layers=source"}},
+    {"addition", {"name=add", "input_layers=lstmcell,source2,source3"}},
     {"split", {"name=split", "input_layers=add"}},
     {"fully_connected", {"name=fc_out", "input_layers=split(0)"}},
   };
 
   std::vector<LayerRepresentation> expected = {
     /// timestep 0
-    {"lstm",
-     {"name=lstm/0", "input_layers=source", "max_timestep=3", "timestep=0",
-      "shared_from=lstm/0"}},
+    {"lstmcell",
+     {"name=lstmcell/0", "input_layers=source", "shared_from=lstmcell/0"}},
     {"addition",
-     {"name=add/0", "input_layers=lstm/0,source2,source3",
+     {"name=add/0", "input_layers=lstmcell/0,source2,source3",
       "shared_from=add/0"}},
     {"split", {"name=split/0", "input_layers=add/0", "shared_from=split/0"}},
     {"fully_connected",
      {"name=fc_out/0", "input_layers=split/0(0)", "shared_from=fc_out/0"}},
 
     /// timestep 1
-    {"lstm",
-     {"name=lstm/1", "input_layers=fc_out/0", "shared_from=lstm/0",
-      "max_timestep=3", "timestep=1"}},
+    {"lstmcell",
+     {"name=lstmcell/1", "input_layers=fc_out/0", "shared_from=lstmcell/0"}},
     {"addition",
-     {"name=add/1", "input_layers=lstm/1,source2,split/0(1)",
+     {"name=add/1", "input_layers=lstmcell/1,source2,split/0(1)",
       "shared_from=add/0"}},
     {"split", {"name=split/1", "input_layers=add/1", "shared_from=split/0"}},
     {"fully_connected",
      {"name=fc_out/1", "input_layers=split/1(0)", "shared_from=fc_out/0"}},
 
     /// timestep 2
-    {"lstm",
-     {"name=lstm/2", "input_layers=fc_out/1", "shared_from=lstm/0",
-      "max_timestep=3", "timestep=2"}},
+    {"lstmcell",
+     {"name=lstmcell/2", "input_layers=fc_out/1", "shared_from=lstmcell/0"}},
     {"addition",
-     {"name=add/2", "input_layers=lstm/2,source2,split/1(1)",
+     {"name=add/2", "input_layers=lstmcell/2,source2,split/1(1)",
       "shared_from=add/0"}},
     {"split", {"name=split/2", "input_layers=add/2", "shared_from=split/0"}},
     {"fully_connected",
@@ -311,7 +302,7 @@ TEST(RecurrentRealizer, recurrent_multi_inout_multi_connection_end_p) {
   RecurrentRealizer r(
     {
       "unroll_for=3",
-      "recurrent_input=lstm,add(2)",
+      "recurrent_input=lstmcell,add(2)",
       "recurrent_output=fc_out,split(1)",
       "as_sequence=split(1)",
     },
@@ -326,47 +317,44 @@ TEST(RecurrentRealizer, recurrent_multi_inout_multi_connection_end_p) {
     });
 
   /// @note for below graph,
-  /// 1. fc_out feds back to lstm
+  /// 1. fc_out feds back to lstmcell
   /// 2. ouput_dummy feds back to source2_dummy
   /// ========================================================
-  /// lstm        -------- addition - split ---- fc_out (to_lstm)
+  /// lstmcell        -------- addition - split ---- fc_out (to_lstmcell)
   /// source2_dummy   --/                  \----- (to addition 3)
   std::vector<LayerRepresentation> before = {
-    {"lstm", {"name=lstm", "input_layers=source"}},
-    {"addition", {"name=add", "input_layers=lstm,source2,source3"}},
+    {"lstmcell", {"name=lstmcell", "input_layers=source"}},
+    {"addition", {"name=add", "input_layers=lstmcell,source2,source3"}},
     {"split", {"name=split", "input_layers=add"}},
     {"fully_connected", {"name=fc_out", "input_layers=split(0)"}},
   };
 
   std::vector<LayerRepresentation> expected = {
     /// timestep 0
-    {"lstm",
-     {"name=lstm/0", "input_layers=source", "max_timestep=3", "timestep=0",
-      "shared_from=lstm/0"}},
+    {"lstmcell",
+     {"name=lstmcell/0", "input_layers=source", "shared_from=lstmcell/0"}},
     {"addition",
-     {"name=add/0", "input_layers=lstm/0,source2,source3",
+     {"name=add/0", "input_layers=lstmcell/0,source2,source3",
       "shared_from=add/0"}},
     {"split", {"name=split/0", "input_layers=add/0", "shared_from=split/0"}},
     {"fully_connected",
      {"name=fc_out/0", "input_layers=split/0(0)", "shared_from=fc_out/0"}},
 
     /// timestep 1
-    {"lstm",
-     {"name=lstm/1", "input_layers=fc_out/0", "shared_from=lstm/0",
-      "max_timestep=3", "timestep=1"}},
+    {"lstmcell",
+     {"name=lstmcell/1", "input_layers=fc_out/0", "shared_from=lstmcell/0"}},
     {"addition",
-     {"name=add/1", "input_layers=lstm/1,source2,split/0(1)",
+     {"name=add/1", "input_layers=lstmcell/1,source2,split/0(1)",
       "shared_from=add/0"}},
     {"split", {"name=split/1", "input_layers=add/1", "shared_from=split/0"}},
     {"fully_connected",
      {"name=fc_out/1", "input_layers=split/1(0)", "shared_from=fc_out/0"}},
 
     /// timestep 2
-    {"lstm",
-     {"name=lstm/2", "input_layers=fc_out/1", "shared_from=lstm/0",
-      "max_timestep=3", "timestep=2"}},
+    {"lstmcell",
+     {"name=lstmcell/2", "input_layers=fc_out/1", "shared_from=lstmcell/0"}},
     {"addition",
-     {"name=add/2", "input_layers=lstm/2,source2,split/1(1)",
+     {"name=add/2", "input_layers=lstmcell/2,source2,split/1(1)",
       "shared_from=add/0"}},
     {"split", {"name=split/2", "input_layers=add/2", "shared_from=split/0"}},
     {"fully_connected",
index 7859851..9aa3364 100644 (file)
@@ -142,27 +142,15 @@ static std::unique_ptr<NeuralNetwork> makeSingleLSTM() {
   nn->setProperty({"batch_size=3"});
 
   auto outer_graph = makeGraph({
-    {"input", {"name=input", "input_shape=1:1:2"}},
-    /// here lstm is being inserted
-    {"mse", {"name=loss", "input_layers=lstm_scope/a1"}},
+    {"input", {"name=input", "input_shape=1:2:2"}},
+    {"lstm",
+     {"name=a1", "unit=2", "integrate_bias=false", "return_sequences=true"}},
+    {"mse", {"name=loss", "input_layers=a1"}},
   });
   for (auto &node : outer_graph) {
     nn->addLayer(node);
   }
 
-  auto lstm = makeGraph({
-    {"lstm", {"name=a1", "unit=2", "integrate_bias=false"}},
-  });
-
-  nn->addWithReferenceLayers(lstm, "lstm_scope", {"input"}, {"a1"}, {"a1"},
-                             ml::train::ReferenceLayersType::RECURRENT,
-                             {
-                               "unroll_for=2",
-                               "as_sequence=a1",
-                               "recurrent_input=a1",
-                               "recurrent_output=a1",
-                             });
-
   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
   return nn;
 }
@@ -172,28 +160,17 @@ static std::unique_ptr<NeuralNetwork> makeStackedLSTM() {
   nn->setProperty({"batch_size=3"});
 
   auto outer_graph = makeGraph({
-    {"input", {"name=input", "input_shape=1:1:2"}},
-    /// here lstm is being inserted
-    {"mse", {"name=loss", "input_layers=lstm_scope/a2"}},
+    {"input", {"name=input", "input_shape=1:2:2"}},
+    {"lstm",
+     {"name=a1", "unit=2", "integrate_bias=false", "return_sequences=true"}},
+    {"lstm",
+     {"name=a2", "unit=2", "integrate_bias=false", "return_sequences=true"}},
+    {"mse", {"name=loss"}},
   });
   for (auto &node : outer_graph) {
     nn->addLayer(node);
   }
 
-  auto lstm = makeGraph({
-    {"lstm", {"name=a1", "unit=2", "integrate_bias=false"}},
-    {"lstm", {"name=a2", "unit=2", "integrate_bias=false", "input_layers=a1"}},
-  });
-
-  nn->addWithReferenceLayers(lstm, "lstm_scope", {"input"}, {"a1"}, {"a2"},
-                             ml::train::ReferenceLayersType::RECURRENT,
-                             {
-                               "unroll_for=2",
-                               "as_sequence=a2",
-                               "recurrent_input=a1",
-                               "recurrent_output=a2",
-                             });
-
   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
   return nn;
 }