[zoneout lstmcell] support multi in/out
authorhyeonseok lee <hs89.lee@samsung.com>
Sat, 18 Dec 2021 22:20:33 +0000 (07:20 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 30 Dec 2021 10:05:55 +0000 (19:05 +0900)
 - Refactoring zoneout lstmcell layer to support multi in/out (3 input / 2output)
 - Regenerate zoneout testcase for multi in/out

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/layers/lstmcell_core.cpp
nntrainer/layers/lstmcell_core.h
nntrainer/layers/zoneout_lstmcell.cpp
nntrainer/layers/zoneout_lstmcell.h
packaging/unittest_models_v2.tar.gz
test/input_gen/genModelsRecurrent_v2.py
test/unittest/models/unittest_models_recurrent.cpp

index c6d31bd..2993f79 100644 (file)
@@ -70,7 +70,7 @@ void lstmcell_calcGradient(
   ActiFunc &recurrent_acti_func, const Tensor &input,
   const Tensor &prev_hidden_state, Tensor &d_prev_hidden_state,
   const Tensor &prev_cell_state, Tensor &d_prev_cell_state,
-  Tensor &d_hidden_state, const Tensor &cell_state, Tensor &d_cell_state,
+  const Tensor &d_hidden_state, const Tensor &cell_state, const Tensor &d_cell_state,
   Tensor &d_weight_ih, const Tensor &weight_hh, Tensor &d_weight_hh,
   Tensor &d_bias_h, Tensor &d_bias_ih, Tensor &d_bias_hh, const Tensor &ifgo,
   Tensor &d_ifgo) {
index cdb33ed..bf05390 100644 (file)
@@ -36,7 +36,7 @@ void lstmcell_calcGradient(
   ActiFunc &recurrent_acti_func, const Tensor &input,
   const Tensor &prev_hidden_state, Tensor &d_prev_hidden_state,
   const Tensor &prev_cell_state, Tensor &d_prev_cell_state,
-  Tensor &d_hidden_state, const Tensor &cell_state, Tensor &d_cell_state,
+  const Tensor &d_hidden_state, const Tensor &cell_state, const Tensor &d_cell_state,
   Tensor &d_weight_ih, const Tensor &weight_hh, Tensor &d_weight_hh,
   Tensor &d_bias_h, Tensor &d_bias_ih, Tensor &d_bias_hh, const Tensor &ifgo,
   Tensor &d_ifgo);
index 1f5e01a..6369334 100644 (file)
 
 namespace nntrainer {
 
-static constexpr size_t SINGLE_INOUT_IDX = 0;
-
 enum ZoneoutLSTMParams {
   weight_ih,
   weight_hh,
   bias_h,
   bias_ih,
   bias_hh,
-  hidden_state,
-  cell_state,
   ifgo,
   lstm_cell_state,
   hidden_state_zoneout_mask,
@@ -95,27 +91,51 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
   const unsigned int max_timestep =
     std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
 
-  if (context.getNumInputs() != 1)
-    throw std::invalid_argument("ZoneoutLSTMCellLayer takes only one input");
-  if (std::get<props::MaxTimestep>(zoneout_lstmcell_props).empty())
-    throw std::invalid_argument("Number of unroll steps(max timestep) must be "
-                                "provided to zoneout LSTM cells");
-  if (std::get<props::Timestep>(zoneout_lstmcell_props).empty())
+  if (context.getNumInputs() != 3) {
     throw std::invalid_argument(
-      "Current timestep must be provided to zoneout LSTM cell");
+      "Number of input is not 3. ZoneoutLSTMCellLayer should takes 3 inputs");
+  }
 
   // input_dim = [ batch_size, 1, 1, feature_size ]
-  const TensorDim &input_dim = context.getInputDimensions()[0];
+  const TensorDim &input_dim = context.getInputDimensions()[INOUT_INDEX::INPUT];
   if (input_dim.channel() != 1 || input_dim.height() != 1)
     throw std::invalid_argument("Input must be single time dimension for "
                                 "ZoneoutLSTMCell (shape should be "
                                 "[batch_size, 1, 1, feature_size])");
+  // input_hidden_state_dim = [ batch_size, 1, 1, unit ]
+  const TensorDim &input_hidden_state_dim =
+    context.getInputDimensions()[INOUT_INDEX::INPUT_HIDDEN_STATE];
+  if (input_hidden_state_dim.channel() != 1 ||
+      input_hidden_state_dim.height() != 1) {
+    throw std::invalid_argument(
+      "Input hidden state's dimension should be"
+      "[batch_size, 1, 1, unit] for zoneout LSTMcell");
+  }
+  // input_cell_state_dim = [ batch_size, 1, 1, unit ]
+  const TensorDim &input_cell_state_dim =
+    context.getInputDimensions()[INOUT_INDEX::INPUT_CELL_STATE];
+  if (input_cell_state_dim.channel() != 1 ||
+      input_cell_state_dim.height() != 1) {
+    throw std::invalid_argument(
+      "Input cell state's dimension should be"
+      "[batch_size, 1, 1, unit] for zoneout LSTMcell");
+  }
   const unsigned int batch_size = input_dim.batch();
   const unsigned int feature_size = input_dim.width();
 
-  // output_dim = [ batch_size, 1, 1, unit ]
-  const TensorDim output_dim(batch_size, 1, 1, unit);
-  context.setOutputDimensions({output_dim});
+  // output_hidden_state_dim = [ batch_size, 1, 1, unit ]
+  const TensorDim output_hidden_state_dim = input_hidden_state_dim;
+  // output_cell_state_dim = [ batch_size, 1, 1, unit ]
+  const TensorDim output_cell_state_dim = input_cell_state_dim;
+
+  std::vector<VarGradSpecV2> out_specs;
+  out_specs.push_back(
+    InitLayerContext::outSpec(output_hidden_state_dim, "output_hidden_state",
+                              TensorLifespan::FORWARD_DERIV_LIFESPAN));
+  out_specs.push_back(
+    InitLayerContext::outSpec(output_cell_state_dim, "output_cell_state",
+                              TensorLifespan::FORWARD_DERIV_LIFESPAN));
+  context.requestOutputs(std::move(out_specs));
 
   // weight_initializer can be set seperately. weight_ih initializer,
   // weight_hh initializer kernel initializer & recurrent_initializer in keras
@@ -157,21 +177,6 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
     }
   }
 
-  /**
-   * TODO: hidden_state is only used from the previous timestep. Once it is
-   * supported as input, no need to cache the hidden_state itself
-   */
-  /** hidden_state_dim = [ max_timestep * batch_size, 1, 1, unit ] */
-  const TensorDim hidden_state_dim(max_timestep * batch_size, 1, 1, unit);
-  wt_idx[ZoneoutLSTMParams::hidden_state] = context.requestTensor(
-    hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true,
-    TensorLifespan::ITERATION_LIFESPAN, false);
-  /** cell_state_dim = [ max_timestep * batch_size, 1, 1, unit ] */
-  const TensorDim cell_state_dim(max_timestep * batch_size, 1, 1, unit);
-  wt_idx[ZoneoutLSTMParams::cell_state] = context.requestTensor(
-    cell_state_dim, "cell_state", Tensor::Initializer::NONE, true,
-    TensorLifespan::ITERATION_LIFESPAN, false);
-
   /** ifgo_dim = [ batch_size, 1, 1, NUM_GATE * unit ] */
   const TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit);
   wt_idx[ZoneoutLSTMParams::ifgo] =
@@ -244,8 +249,13 @@ void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
   const unsigned int timestep =
     std::get<props::Timestep>(zoneout_lstmcell_props).get();
 
-  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
-  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
+  const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
+  const Tensor &prev_hidden_state =
+    context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
+  const Tensor &prev_cell_state =
+    context.getInput(INOUT_INDEX::INPUT_CELL_STATE);
+  Tensor &hidden_state = context.getOutput(INOUT_INDEX::OUTPUT_HIDDEN_STATE);
+  Tensor &cell_state = context.getOutput(INOUT_INDEX::OUTPUT_CELL_STATE);
 
   const unsigned int batch_size = input.getDim().batch();
 
@@ -254,41 +264,18 @@ void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
   const Tensor &weight_hh =
     context.getWeight(wt_idx[ZoneoutLSTMParams::weight_hh]);
   Tensor empty;
-  Tensor &bias_h = !disable_bias && integrate_bias
-                     ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_h])
-                     : empty;
-  Tensor &bias_ih = !disable_bias && !integrate_bias
-                      ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_ih])
-                      : empty;
-  Tensor &bias_hh = !disable_bias && !integrate_bias
-                      ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_hh])
-                      : empty;
-
-  Tensor &hs = context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state]);
-  hs.reshape({max_timestep, 1, batch_size, unit});
-  Tensor prev_hidden_state;
-  if (!timestep) {
-    prev_hidden_state = Tensor(batch_size, 1, 1, unit);
-    prev_hidden_state.setZero();
-  } else {
-    prev_hidden_state = hs.getBatchSlice(timestep - 1, 1);
-    prev_hidden_state.reshape({batch_size, 1, 1, unit});
-  }
-  Tensor hidden_state = hs.getBatchSlice(timestep, 1);
-  hidden_state.reshape({batch_size, 1, 1, unit});
-
-  Tensor &cs = context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state]);
-  cs.reshape({max_timestep, 1, batch_size, unit});
-  Tensor prev_cell_state;
-  if (!timestep) {
-    prev_cell_state = Tensor(batch_size, 1, 1, unit);
-    prev_cell_state.setZero();
-  } else {
-    prev_cell_state = cs.getBatchSlice(timestep - 1, 1);
-    prev_cell_state.reshape({batch_size, 1, 1, unit});
-  }
-  Tensor cell_state = cs.getBatchSlice(timestep, 1);
-  cell_state.reshape({batch_size, 1, 1, unit});
+  const Tensor &bias_h =
+    !disable_bias && integrate_bias
+      ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_h])
+      : empty;
+  const Tensor &bias_ih =
+    !disable_bias && !integrate_bias
+      ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_ih])
+      : empty;
+  const Tensor &bias_hh =
+    !disable_bias && !integrate_bias
+      ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_hh])
+      : empty;
 
   Tensor &ifgo = context.getTensor(wt_idx[ZoneoutLSTMParams::ifgo]);
 
@@ -343,15 +330,14 @@ void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
     prev_cell_state.multiply(prev_cell_state_zoneout_mask, cell_state, 1.0f);
   }
   // Todo: zoneout at inference
-
-  output.copyData(hidden_state);
 }
 
 void ZoneoutLSTMCellLayer::calcDerivative(RunLayerContext &context) {
   Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]);
   const Tensor &weight_ih =
     context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]);
-  Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
+  Tensor &outgoing_derivative =
+    context.getOutgoingDerivative(INOUT_INDEX::INPUT);
 
   lstmcell_calcDerivative(d_ifgo, weight_ih, outgoing_derivative);
 }
@@ -373,9 +359,19 @@ void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
   const unsigned int timestep =
     std::get<props::Timestep>(zoneout_lstmcell_props).get();
 
-  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
-  const Tensor &incoming_derivative =
-    context.getIncomingDerivative(SINGLE_INOUT_IDX);
+  const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
+  const Tensor &prev_hidden_state =
+    context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
+  Tensor &d_prev_hidden_state =
+    context.getOutgoingDerivative(INOUT_INDEX::INPUT_HIDDEN_STATE);
+  const Tensor &prev_cell_state =
+    context.getInput(INOUT_INDEX::INPUT_CELL_STATE);
+  Tensor &d_prev_cell_state =
+    context.getOutgoingDerivative(INOUT_INDEX::INPUT_CELL_STATE);
+  const Tensor &d_hidden_state =
+    context.getIncomingDerivative(INOUT_INDEX::OUTPUT_HIDDEN_STATE);
+  const Tensor &d_cell_state =
+    context.getIncomingDerivative(INOUT_INDEX::OUTPUT_CELL_STATE);
 
   unsigned int batch_size = input.getDim().batch();
 
@@ -399,56 +395,6 @@ void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
       ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_hh])
       : empty;
 
-  Tensor &hs = context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state]);
-  hs.reshape({max_timestep, 1, batch_size, unit});
-  Tensor prev_hidden_state;
-  if (!timestep) {
-    prev_hidden_state = Tensor(batch_size, 1, 1, unit);
-    prev_hidden_state.setZero();
-  } else {
-    prev_hidden_state = hs.getBatchSlice(timestep - 1, 1);
-    prev_hidden_state.reshape({batch_size, 1, 1, unit});
-  }
-
-  Tensor &d_hs = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::hidden_state]);
-  d_hs.reshape({max_timestep, 1, batch_size, unit});
-  Tensor d_prev_hidden_state;
-  if (!timestep) {
-    d_prev_hidden_state = Tensor(batch_size, 1, 1, unit);
-    d_prev_hidden_state.setZero();
-  } else {
-    d_prev_hidden_state = d_hs.getBatchSlice(timestep - 1, 1);
-    d_prev_hidden_state.reshape({batch_size, 1, 1, unit});
-  }
-  Tensor d_hidden_state = d_hs.getBatchSlice(timestep, 1);
-  d_hidden_state.reshape({batch_size, 1, 1, unit});
-
-  Tensor &cs = context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state]);
-  cs.reshape({max_timestep, 1, batch_size, unit});
-  Tensor prev_cell_state;
-  if (!timestep) {
-    prev_cell_state = Tensor(batch_size, 1, 1, unit);
-    prev_cell_state.setZero();
-  } else {
-    prev_cell_state = cs.getBatchSlice(timestep - 1, 1);
-    prev_cell_state.reshape({batch_size, 1, 1, unit});
-  }
-  Tensor cell_state = cs.getBatchSlice(timestep, 1);
-  cell_state.reshape({batch_size, 1, 1, unit});
-
-  Tensor &d_cs = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::cell_state]);
-  d_cs.reshape({max_timestep, 1, batch_size, unit});
-  Tensor d_prev_cell_state;
-  if (!timestep) {
-    d_prev_cell_state = Tensor(batch_size, 1, 1, unit);
-    d_prev_cell_state.setZero();
-  } else {
-    d_prev_cell_state = d_cs.getBatchSlice(timestep - 1, 1);
-    d_prev_cell_state.reshape({batch_size, 1, 1, unit});
-  }
-  Tensor d_cell_state = d_cs.getBatchSlice(timestep, 1);
-  d_cell_state.reshape({batch_size, 1, 1, unit});
-
   Tensor &ifgo = context.getTensor(wt_idx[ZoneoutLSTMParams::ifgo]);
   Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]);
 
@@ -457,25 +403,28 @@ void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
   Tensor &d_lstm_cell_state =
     context.getTensorGrad(wt_idx[ZoneoutLSTMParams::lstm_cell_state]);
 
-  if (timestep + 1 == max_timestep) {
+  if (context.isGradientFirstAccess(wt_idx[ZoneoutLSTMParams::weight_ih])) {
     d_weight_ih.setZero();
+  }
+  if (context.isGradientFirstAccess(wt_idx[ZoneoutLSTMParams::weight_hh])) {
     d_weight_hh.setZero();
-    if (!disable_bias) {
-      if (integrate_bias) {
+  }
+  if (!disable_bias) {
+    if (integrate_bias) {
+      if (context.isGradientFirstAccess(wt_idx[ZoneoutLSTMParams::bias_h])) {
         d_bias_h.setZero();
-      } else {
+      }
+    } else {
+      if (context.isGradientFirstAccess(wt_idx[ZoneoutLSTMParams::bias_ih])) {
         d_bias_ih.setZero();
+      }
+      if (context.isGradientFirstAccess(wt_idx[ZoneoutLSTMParams::bias_hh])) {
         d_bias_hh.setZero();
       }
     }
-    d_hidden_state.setZero();
-    d_cell_state.setZero();
   }
 
-  d_hidden_state.add_i(incoming_derivative);
-
   Tensor d_prev_hidden_state_residual;
-
   Tensor &hs_zoneout_mask =
     test
       ? context.getWeight(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
@@ -494,10 +443,10 @@ void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
 
   d_hidden_state.multiply(prev_hidden_state_zoneout_mask,
                           d_prev_hidden_state_residual);
-  d_hidden_state.multiply_i(hidden_state_zoneout_mask);
+  Tensor d_hidden_state_masked;
+  d_hidden_state.multiply(hidden_state_zoneout_mask, d_hidden_state_masked);
 
   Tensor d_prev_cell_state_residual;
-
   Tensor &cs_zoneout_mask =
     test
       ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
@@ -518,12 +467,12 @@ void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
                         d_prev_cell_state_residual);
   d_cell_state.multiply(cell_state_zoneout_mask, d_lstm_cell_state);
 
-  lstmcell_calcGradient(unit, batch_size, disable_bias, integrate_bias,
-                        acti_func, recurrent_acti_func, input,
-                        prev_hidden_state, d_prev_hidden_state, prev_cell_state,
-                        d_prev_cell_state, d_hidden_state, lstm_cell_state,
-                        d_lstm_cell_state, d_weight_ih, weight_hh, d_weight_hh,
-                        d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo);
+  lstmcell_calcGradient(
+    unit, batch_size, disable_bias, integrate_bias, acti_func,
+    recurrent_acti_func, input, prev_hidden_state, d_prev_hidden_state,
+    prev_cell_state, d_prev_cell_state, d_hidden_state_masked, lstm_cell_state,
+    d_lstm_cell_state, d_weight_ih, weight_hh, d_weight_hh, d_bias_h, d_bias_ih,
+    d_bias_hh, ifgo, d_ifgo);
 
   d_prev_hidden_state.add_i(d_prev_hidden_state_residual);
   d_prev_cell_state.add_i(d_prev_cell_state_residual);
@@ -534,10 +483,6 @@ void ZoneoutLSTMCellLayer::setBatch(RunLayerContext &context,
   const unsigned int max_timestep =
     std::get<props::MaxTimestep>(zoneout_lstmcell_props);
 
-  context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state],
-                       max_timestep * batch);
-  context.updateTensor(wt_idx[ZoneoutLSTMParams::cell_state],
-                       max_timestep * batch);
   context.updateTensor(wt_idx[ZoneoutLSTMParams::ifgo], batch);
   context.updateTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state], batch);
 
index 515f55e..4ad9e36 100644 (file)
@@ -168,6 +168,13 @@ public:
 
 private:
   static constexpr unsigned int NUM_GATE = 4;
+  enum INOUT_INDEX {
+    INPUT = 0,
+    INPUT_HIDDEN_STATE = 1,
+    INPUT_CELL_STATE = 2,
+    OUTPUT_HIDDEN_STATE = 0,
+    OUTPUT_CELL_STATE = 1
+  };
 
   /**
    * Unit: number of output neurons
@@ -185,7 +192,7 @@ private:
              props::RecurrentActivation, HiddenStateZoneOutRate,
              CellStateZoneOutRate, Test, props::MaxTimestep, props::Timestep>
     zoneout_lstmcell_props;
-  std::array<unsigned int, 11> wt_idx; /**< indices of the weights */
+  std::array<unsigned int, 9> wt_idx; /**< indices of the weights */
 
   /**
    * @brief     activation function for h_t : default is tanh
index 5e67378..bac0242 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
index dd2d852..24f807c 100644 (file)
@@ -95,12 +95,14 @@ class ZoneoutLSTMStacked(torch.nn.Module):
             ]
         )
         self.unroll_for = unroll_for
+        self.num_lstm = num_lstm
         self.loss = torch.nn.MSELoss()
 
     def forward(self, inputs, labels):
-        hs = [torch.zeros_like(inputs[0]) for _ in self.zoneout_lstms]
-        cs = [torch.zeros_like(inputs[0]) for _ in self.zoneout_lstms]
         out = inputs[0]
+        states = inputs[1:]
+        hs = [states[2 * i] for i in range(self.num_lstm)]
+        cs = [states[2 * i + 1] for i in range(self.num_lstm)]
         ret = []
         for num_unroll in range(self.unroll_for):
             for i, (zoneout_lstm, h, c) in enumerate(zip(self.zoneout_lstms, hs, cs)):
@@ -197,147 +199,165 @@ if __name__ == "__main__":
         name="lstm_stacked",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 0.0, 0.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=0.0, cell_state_zoneout_rate=0.0),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_single_000_000",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 0.0, 0.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=0.0, cell_state_zoneout_rate=0.0),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_stacked_000_000",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 0.5, 0.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=0.5, cell_state_zoneout_rate=0.0),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_single_050_000",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 0.5, 0.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=0.5, cell_state_zoneout_rate=0.0),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_stacked_050_000",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 1.0, 0.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=1.0, cell_state_zoneout_rate=0.0),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_single_100_000",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 1.0, 0.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=1.0, cell_state_zoneout_rate=0.0),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_stacked_100_000",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 0.0, 0.5]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=0.0, cell_state_zoneout_rate=0.5),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_single_000_050",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 0.0, 0.5]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=0.0, cell_state_zoneout_rate=0.5),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_stacked_000_050",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 0.5, 0.5]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=0.5, cell_state_zoneout_rate=0.5),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_single_050_050",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 0.5, 0.5]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=0.5, cell_state_zoneout_rate=0.5),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_stacked_050_050",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 1.0, 0.5]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=1.0, cell_state_zoneout_rate=0.5),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_single_100_050",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 1.0, 0.5]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=1.0, cell_state_zoneout_rate=0.5),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_stacked_100_050",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 0.0, 1.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=0.0, cell_state_zoneout_rate=1.0),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_single_000_100",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 0.0, 1.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=0.0, cell_state_zoneout_rate=1.0),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_stacked_000_100",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 0.5, 1.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=0.5, cell_state_zoneout_rate=1.0),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_single_050_100",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 0.5, 1.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=0.5, cell_state_zoneout_rate=1.0),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_stacked_050_100",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 1.0, 1.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=1.0, cell_state_zoneout_rate=1.0),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_single_100_100",
     )
 
+    unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 1.0, 1.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=1.0, cell_state_zoneout_rate=1.0),
-        iteration=2,
-        input_dims=[(1, 2)],
-        label_dims=[(1, 2, 2)],
+        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="zoneout_lstm_stacked_100_100",
     )
 
index 3a10fa8..bce03e1 100644 (file)
@@ -266,29 +266,37 @@ static std::unique_ptr<NeuralNetwork> makeSingleZoneoutLSTMCell() {
 
   auto outer_graph = makeGraph({
     {"input", {"name=input", "input_shape=1:1:2"}},
+    {"input", {"name=input_hidden_state", "input_shape=1:1:2"}},
+    {"input", {"name=input_cell_state", "input_shape=1:1:2"}},
     /// here zoneout_lstm_cell is being inserted
-    {"mse", {"name=loss", "input_layers=zoneout_lstm_scope/a1"}},
+    {"mse", {"name=loss", "input_layers=zoneout_lstm_scope/a1(0)"}},
   });
   for (auto &node : outer_graph) {
     nn->addLayer(node);
   }
 
   auto zoneout_lstm = makeGraph({
+    {"input", {"name=dummy_0", "input_shape=1"}},
+    {"input", {"name=dummy_1", "input_shape=1"}},
+    {"input", {"name=dummy_2", "input_shape=1"}},
     {"zoneout_lstmcell",
      {"name=a1", "unit=2", "hidden_state_zoneout_rate=1.0",
-      "cell_state_zoneout_rate=1.0", "test=true", "integrate_bias=false"}},
+      "cell_state_zoneout_rate=1.0", "test=true",
+      "input_layers=dummy_0, dummy_1, dummy_2"}},
   });
 
-  nn->addWithReferenceLayers(zoneout_lstm, "zoneout_lstm_scope", {"input"},
-                             {"a1"}, {"a1"},
-                             ml::train::ReferenceLayersType::RECURRENT,
-                             {
-                               "unroll_for=2",
-                               "as_sequence=a1",
-                               "recurrent_input=a1",
-                               "recurrent_output=a1",
-                             });
-
+  nn->addWithReferenceLayers(
+    zoneout_lstm, "zoneout_lstm_scope",
+    {"input", "input_hidden_state", "input_cell_state"}, {"a1"}, {"a1"},
+    ml::train::ReferenceLayersType::RECURRENT,
+    {
+      "unroll_for=2",
+      "as_sequence=a1",
+      "recurrent_input=a1(0), a1(1), a1(2)",
+      "recurrent_output=a1(0), a1(0), a1(1)",
+    });
+
+  nn->setProperty({"input_layers=input, input_hidden_state, input_cell_state"});
   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
   return nn;
 }
@@ -299,33 +307,53 @@ static std::unique_ptr<NeuralNetwork> makeStackedZoneoutLSTMCell() {
 
   auto outer_graph = makeGraph({
     {"input", {"name=input", "input_shape=1:1:2"}},
+    {"input", {"name=a1_input_hidden_state", "input_shape=1:1:2"}},
+    {"input", {"name=a1_input_cell_state", "input_shape=1:1:2"}},
+    {"input", {"name=a2_input_hidden_state", "input_shape=1:1:2"}},
+    {"input", {"name=a2_input_cell_state", "input_shape=1:1:2"}},
     /// here zoneout_lstm_cell is being inserted
-    {"mse", {"name=loss", "input_layers=zoneout_lstm_scope/a2"}},
+    {"mse", {"name=loss", "input_layers=zoneout_lstm_scope/a2(0)"}},
   });
   for (auto &node : outer_graph) {
     nn->addLayer(node);
   }
 
   auto zoneout_lstm = makeGraph({
+    {"input", {"name=dummy_0", "input_shape=1"}},
+    {"input", {"name=dummy_1", "input_shape=1"}},
+    {"input", {"name=dummy_2", "input_shape=1"}},
+    {"input", {"name=dummy_3", "input_shape=1"}},
+    {"input", {"name=dummy_4", "input_shape=1"}},
     {"zoneout_lstmcell",
      {"name=a1", "unit=2", "hidden_state_zoneout_rate=1.0",
-      "cell_state_zoneout_rate=1.0", "test=true", "integrate_bias=false"}},
+      "cell_state_zoneout_rate=1.0", "test=true",
+      "input_layers=dummy_0, dummy_1, dummy_2"}},
     {"zoneout_lstmcell",
      {"name=a2", "unit=2", "hidden_state_zoneout_rate=1.0",
-      "cell_state_zoneout_rate=1.0", "test=true", "integrate_bias=false",
-      "input_layers=a1"}},
+      "cell_state_zoneout_rate=1.0", "test=true",
+      "input_layers=a1(0), dummy_3, dummy_4"}},
   });
 
-  nn->addWithReferenceLayers(zoneout_lstm, "zoneout_lstm_scope", {"input"},
-                             {"a1"}, {"a2"},
-                             ml::train::ReferenceLayersType::RECURRENT,
-                             {
-                               "unroll_for=2",
-                               "as_sequence=a2",
-                               "recurrent_input=a1",
-                               "recurrent_output=a2",
-                             });
-
+  nn->addWithReferenceLayers(
+    zoneout_lstm, "zoneout_lstm_scope",
+    {
+      "input",
+      "a1_input_hidden_state",
+      "a1_input_cell_state",
+      "a2_input_hidden_state",
+      "a2_input_cell_state",
+    },
+    {"a1", "a2"}, {"a2"}, ml::train::ReferenceLayersType::RECURRENT,
+    {
+      "unroll_for=2",
+      "as_sequence=a2",
+      "recurrent_input=a1(0), a1(1), a1(2), a2(1), a2(2)",
+      "recurrent_output=a2(0), a1(0), a1(1), a2(0), a2(1)",
+    });
+
+  nn->setProperty(
+    {"input_layers=input, a1_input_hidden_state, a1_input_cell_state, "
+     "a2_input_hidden_state, a2_input_cell_state"});
   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
   return nn;
 }