[grucell] enable multi inout
authorhyeonseok lee <hs89.lee@samsung.com>
Thu, 30 Dec 2021 04:54:28 +0000 (13:54 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 31 Dec 2021 11:58:47 +0000 (20:58 +0900)
 - Enable multi inout for grucell
 - Generate grucell layer/model unittest

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/compiler/recurrent_realizer.cpp
nntrainer/layers/grucell.cpp
nntrainer/layers/grucell.h
packaging/unittest_layers_v2.tar.gz
packaging/unittest_models_v2.tar.gz
test/input_gen/genLayerTests.py
test/input_gen/genModelsRecurrent_v2.py
test/input_gen/transLayer.py
test/unittest/layers/unittest_layers_grucell.cpp
test/unittest/models/unittest_models_recurrent.cpp

index d0215bc..6e56fd6 100644 (file)
@@ -18,7 +18,6 @@
 #include <base_properties.h>
 #include <common_properties.h>
 #include <connection.h>
-#include <grucell.h>
 #include <input_layer.h>
 #include <layer_node.h>
 #include <lstm.h>
@@ -188,8 +187,7 @@ static void propagateTimestep(LayerNode *node, unsigned int time_step,
   auto is_recurrent_type = [](LayerNode *node) {
     return node->getType() == RNNCellLayer::type ||
            node->getType() == LSTMLayer::type ||
-           node->getType() == ZoneoutLSTMCellLayer::type ||
-           node->getType() == GRUCellLayer::type;
+           node->getType() == ZoneoutLSTMCellLayer::type;
   };
 
   if (is_recurrent_type(node)) {
index 1dc1a79..08e8be9 100644 (file)
 
 namespace nntrainer {
 
-static constexpr size_t SINGLE_INOUT_IDX = 0;
+void grucell_forwarding(const unsigned int unit, const unsigned int batch_size,
+                        const bool disable_bias, const bool integrate_bias,
+                        const bool reset_after, ActiFunc &acti_func,
+                        ActiFunc &recurrent_acti_func, const Tensor &input,
+                        const Tensor &prev_hidden_state, Tensor &hidden_state,
+                        const Tensor &weight_ih, const Tensor &weight_hh,
+                        const Tensor &bias_h, const Tensor &bias_ih,
+                        const Tensor &bias_hh, Tensor &zrg) {
+  input.dot(weight_ih, zrg);
+
+  Tensor update_reset_gate =
+    zrg.getSharedDataTensor({batch_size, 1, 1, 2 * unit}, 0, false);
+  Tensor memory_cell =
+    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 2 * unit, false);
+
+  Tensor weight_hh_update_reset_gate;
+  Tensor weight_hh_memory_cell;
+  weight_hh_update_reset_gate.copy_with_stride(
+    weight_hh.getSharedDataTensor({unit, 2 * unit}, 0, false));
+  weight_hh_memory_cell.copy_with_stride(
+    weight_hh.getSharedDataTensor({unit, unit}, 2 * unit, false));
+
+  update_reset_gate.add_i_strided(
+    prev_hidden_state.dot(weight_hh_update_reset_gate));
+  if (!disable_bias) {
+    if (integrate_bias) {
+      const Tensor bias_h_update_reset_gate =
+        bias_h.getSharedDataTensor({2 * unit}, 0);
+      update_reset_gate.add_i(bias_h_update_reset_gate);
+    } else {
+      const Tensor bias_ih_update_reset_gate =
+        bias_ih.getSharedDataTensor({2 * unit}, 0);
+      update_reset_gate.add_i(bias_ih_update_reset_gate);
+      const Tensor bias_hh_update_reset_gate =
+        bias_hh.getSharedDataTensor({2 * unit}, 0);
+      update_reset_gate.add_i(bias_hh_update_reset_gate);
+    }
+  }
+
+  recurrent_acti_func.run_fn(update_reset_gate, update_reset_gate);
+
+  Tensor update_gate =
+    update_reset_gate.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
+  Tensor reset_gate = update_reset_gate.getSharedDataTensor(
+    {batch_size, 1, 1, unit}, unit, false);
+
+  Tensor temp;
+  if (reset_after) {
+    prev_hidden_state.dot(weight_hh_memory_cell, temp);
+    if (!disable_bias && !integrate_bias) {
+      const Tensor bias_hh_memory_cell =
+        bias_hh.getSharedDataTensor({unit}, 2 * unit);
+      temp.add_i(bias_hh_memory_cell);
+    }
+    temp.multiply_i_strided(reset_gate);
+    memory_cell.add_i_strided(temp);
+  } else {
+    reset_gate.multiply_strided(prev_hidden_state, temp);
+    memory_cell.add_i_strided(temp.dot(weight_hh_memory_cell));
+    if (!disable_bias && !integrate_bias) {
+      const Tensor bias_hh_memory_cell =
+        bias_hh.getSharedDataTensor({unit}, 2 * unit);
+      memory_cell.add_i(bias_hh_memory_cell);
+    }
+  }
+  if (!disable_bias) {
+    if (integrate_bias) {
+      const Tensor bias_h_memory_cell =
+        bias_h.getSharedDataTensor({unit}, 2 * unit);
+      memory_cell.add_i(bias_h_memory_cell);
+    } else {
+      const Tensor bias_ih_memory_cell =
+        bias_ih.getSharedDataTensor({unit}, 2 * unit);
+      memory_cell.add_i(bias_ih_memory_cell);
+    }
+  }
+
+  acti_func.run_fn(memory_cell, memory_cell);
+
+  update_gate.multiply_strided(prev_hidden_state, hidden_state);
+  temp = update_gate.multiply(-1.0).add(1.0);
+
+  hidden_state.add_i(memory_cell.multiply_strided(temp));
+}
+
+void grucell_calcGradient(
+  const unsigned int unit, const unsigned int batch_size,
+  const bool disable_bias, const bool integrate_bias, const bool reset_after,
+  ActiFunc &acti_func, ActiFunc &recurrent_acti_func, const Tensor &input,
+  const Tensor &prev_hidden_state, Tensor &d_prev_hidden_state,
+  const Tensor &d_hidden_state, Tensor &d_weight_ih, const Tensor &weight_hh,
+  Tensor &d_weight_hh, Tensor &d_bias_h, Tensor &d_bias_ih,
+  const Tensor &bias_hh, Tensor &d_bias_hh, const Tensor &zrg, Tensor &d_zrg) {
+  Tensor d_weight_hh_update_reset_gate =
+    d_weight_hh.getSharedDataTensor({unit, 2 * unit}, 0, false);
+  Tensor d_weight_hh_memory_cell =
+    d_weight_hh.getSharedDataTensor({unit, unit}, 2 * unit, false);
+
+  Tensor update_gate =
+    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
+  Tensor reset_gate =
+    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
+  Tensor memory_cell =
+    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 2 * unit, false);
+
+  Tensor d_update_gate =
+    d_zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
+  Tensor d_reset_gate =
+    d_zrg.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
+  Tensor d_memory_cell =
+    d_zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 2 * unit, false);
+
+  d_hidden_state.multiply_strided(
+    update_gate, d_prev_hidden_state); // d_prev_hidden_state = d1
+  d_hidden_state.multiply_strided(prev_hidden_state,
+                                  d_update_gate); // d_update_gate = d2
+  d_update_gate.add_i_strided(d_hidden_state.multiply_strided(memory_cell),
+                              -1.0f); // d_update_gate = d5
+  update_gate.multiply(-1.0, d_memory_cell);
+  d_memory_cell.add_i(1.0);
+  d_memory_cell.multiply_i_strided(d_hidden_state); // d_memory_cell = d6
+
+  recurrent_acti_func.run_prime_fn(update_gate, d_update_gate,
+                                   d_update_gate); // d_update_gate = d7
+  acti_func.run_prime_fn(memory_cell, d_memory_cell,
+                         d_memory_cell); // d_memory_cell = d8
+
+  Tensor d_update_reset_gate = d_zrg.getSharedDataTensor(
+    {batch_size, 1, 1, 2 * unit}, 0, false); // d_update_gate+d_reset_gate
+
+  Tensor weight_hh_memory_cell;
+  weight_hh_memory_cell.copy_with_stride(
+    weight_hh.getSharedDataTensor({unit, unit}, 2 * unit, false));
+  Tensor weight_hh_update_reset_gate;
+  weight_hh_update_reset_gate.copy_with_stride(
+    weight_hh.getSharedDataTensor({unit, 2 * unit}, 0, false));
+
+  Tensor temp = Tensor(batch_size, 1, 1, unit);
+  Tensor d_memory_cell_contiguous;
+  d_memory_cell_contiguous.copy_with_stride(d_memory_cell);
+
+  if (reset_after) {
+    prev_hidden_state.dot(weight_hh_memory_cell, temp);
+    if (!disable_bias && !integrate_bias) {
+      const Tensor bias_hh_memory_cell =
+        bias_hh.getSharedDataTensor({unit}, 2 * unit);
+      temp.add_i(bias_hh_memory_cell);
+    }
+    d_memory_cell_contiguous.multiply_strided(
+      temp, d_reset_gate); // d_reset_gate = d15
+
+    // reset temp: d_memory_cell_contiguous * reset_gate for
+    // d_bias_hh_memory_cell, d_prev_hidden_state and d_weight_hh_memory_cell
+    d_memory_cell_contiguous.multiply_strided(reset_gate, temp);
+    if (!disable_bias && !integrate_bias) {
+      Tensor d_bias_hh_memory_cell =
+        d_bias_hh.getSharedDataTensor({unit}, 2 * unit);
+      temp.sum(0, d_bias_hh_memory_cell, 1.0, 1.0);
+    }
+    temp.dot(weight_hh_memory_cell, d_prev_hidden_state, false, true,
+             1.0); // d_prev_hidden_state = d1 + d14
+    d_weight_hh_memory_cell.add_i_strided(
+      prev_hidden_state.dot(temp, true, false));
+  } else {
+    if (!disable_bias && !integrate_bias) {
+      Tensor d_bias_hh_memory_cell =
+        d_bias_hh.getSharedDataTensor({unit}, 2 * unit);
+      d_memory_cell.sum(0, d_bias_hh_memory_cell, 1.0, 1.0);
+    }
+
+    d_memory_cell_contiguous.dot(weight_hh_memory_cell, temp, false, true);
+    temp.multiply_strided(prev_hidden_state, d_reset_gate);
+    temp.multiply_strided(reset_gate, d_prev_hidden_state, 1.0f);
+
+    // reset temp: reset_gate * prev_hidden_state for and
+    // d_weight_hh_memory_cell
+    reset_gate.multiply_strided(prev_hidden_state, temp);
+    d_weight_hh_memory_cell.add_i_strided(
+      temp.dot(d_memory_cell_contiguous, true, false));
+  }
+
+  recurrent_acti_func.run_prime_fn(reset_gate, d_reset_gate,
+                                   d_reset_gate); // d_reset_gate = d16
+
+  if (!disable_bias) {
+    if (integrate_bias) {
+      d_zrg.sum(0, d_bias_h, 1.0, 1.0);
+    } else {
+      d_zrg.sum(0, d_bias_ih, 1.0, 1.0);
+      Tensor d_bias_hh_update_reset_gate =
+        d_bias_hh.getSharedDataTensor({2 * unit}, 0);
+      d_bias_hh_update_reset_gate.add_i(
+        d_zrg.sum(0).getSharedDataTensor({2 * unit}, 0));
+    }
+  }
+
+  Tensor d_update_reset_gate_contiguous;
+  d_update_reset_gate_contiguous.copy_with_stride(d_update_reset_gate);
+  d_weight_hh_update_reset_gate.add_i_strided(
+    prev_hidden_state.dot(d_update_reset_gate_contiguous, true, false));
+  input.dot(d_zrg, d_weight_ih, true, false, 1.0f);
+  d_update_reset_gate_contiguous.dot(
+    weight_hh_update_reset_gate, d_prev_hidden_state, false, true,
+    1.0); // d_prev_hidden_state = d1 + d14 + d12 + d17
+}
 
 enum GRUCellParams {
   weight_ih,
@@ -47,7 +251,6 @@ enum GRUCellParams {
   bias_h,
   bias_ih,
   bias_hh,
-  hidden_state,
   zrg,
   dropout_mask
 };
@@ -56,11 +259,10 @@ enum GRUCellParams {
 // tensors
 GRUCellLayer::GRUCellLayer() :
   LayerImpl(),
-  grucell_props(props::Unit(),
+  grucell_props(props::Unit(), props::IntegrateBias(), props::ResetAfter(),
                 props::HiddenStateActivation() = ActivationType::ACT_TANH,
                 props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
-                props::DropOutRate(), props::IntegrateBias(),
-                props::ResetAfter(), props::MaxTimestep(), props::Timestep()),
+                props::DropOutRate()),
   acti_func(ActivationType::ACT_NONE, true),
   recurrent_acti_func(ActivationType::ACT_NONE, true),
   epsilon(1e-3) {
@@ -87,18 +289,26 @@ void GRUCellLayer::finalize(InitLayerContext &context) {
   const ActivationType recurrent_activation_type =
     std::get<props::RecurrentActivation>(grucell_props).get();
   const float dropout_rate = std::get<props::DropOutRate>(grucell_props).get();
-  const unsigned int max_timestep =
-    std::get<props::MaxTimestep>(grucell_props).get();
 
-  if (context.getNumInputs() != 1) {
-    throw std::invalid_argument("GRUCell layer takes only one input");
+  if (context.getNumInputs() != 2) {
+    throw std::invalid_argument(
+      "Number of input is not 2. GRUCell layer takes should takes 2 input");
   }
 
   // input_dim = [ batch_size, 1, 1, feature_size ]
   const TensorDim &input_dim = context.getInputDimensions()[0];
   if (input_dim.channel() != 1 && input_dim.height() != 1) {
     throw std::invalid_argument(
-      "Input must be single time dimension for GRUCell");
+      "Input must be single time dimension for GRUCell(shape should be "
+      "[batch_size, 1, 1, feature_size]");
+  }
+  // input_hidden_state_dim = [ batch_size, 1, 1, unit ]
+  const TensorDim &input_hidden_state_dim =
+    context.getInputDimensions()[INOUT_INDEX::INPUT_HIDDEN_STATE];
+  if (input_hidden_state_dim.channel() != 1 ||
+      input_hidden_state_dim.height() != 1) {
+    throw std::invalid_argument("Input hidden state's dimension should be "
+                                "[batch, 1, 1, unit] for GRUCell");
   }
 
   const unsigned int batch_size = input_dim.batch();
@@ -148,12 +358,6 @@ void GRUCellLayer::finalize(InitLayerContext &context) {
     }
   }
 
-  // hidden_state_dim = [ max_timestep * batch_size, 1, 1, unit ]
-  TensorDim hidden_state_dim(max_timestep * batch_size, 1, 1, unit);
-  wt_idx[GRUCellParams::hidden_state] = context.requestTensor(
-    hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true,
-    TensorLifespan::ITERATION_LIFESPAN, false);
-
   // zrg_dim = [ batch_size, 1, 1, NUM_GATE * unit ]
   TensorDim zrg_dim(batch_size, 1, 1, NUM_GATE * unit);
   wt_idx[GRUCellParams::zrg] =
@@ -192,12 +396,12 @@ void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
     std::get<props::IntegrateBias>(grucell_props).get();
   const bool reset_after = std::get<props::ResetAfter>(grucell_props).get();
   const float dropout_rate = std::get<props::DropOutRate>(grucell_props).get();
-  const unsigned int max_timestep =
-    std::get<props::MaxTimestep>(grucell_props).get();
-  const unsigned int timestep = std::get<props::Timestep>(grucell_props).get();
 
-  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
-  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
+  const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
+  const Tensor &prev_hidden_state =
+    context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
+  Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
+
   const unsigned int batch_size = input.getDim().batch();
 
   const Tensor &weight_ih = context.getWeight(wt_idx[GRUCellParams::weight_ih]);
@@ -213,96 +417,14 @@ void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
                             ? context.getWeight(wt_idx[GRUCellParams::bias_hh])
                             : empty;
 
-  Tensor &hidden_states =
-    context.getTensor(wt_idx[GRUCellParams::hidden_state]);
-  hidden_states.reshape({max_timestep, 1, batch_size, unit});
-  Tensor prev_hidden_state;
-  if (!timestep) {
-    prev_hidden_state = Tensor(batch_size, unit);
-    prev_hidden_state.setZero();
-  } else {
-    prev_hidden_state = hidden_states.getBatchSlice(timestep - 1, 1);
-  }
-  prev_hidden_state.reshape({batch_size, 1, 1, unit});
-  Tensor hidden_state = hidden_states.getBatchSlice(timestep, 1);
-  hidden_state.reshape({batch_size, 1, 1, unit});
-
   Tensor &zrg = context.getTensor(wt_idx[GRUCellParams::zrg]);
 
-  input.dot(weight_ih, zrg);
+  Tensor hidden_state;
 
-  Tensor update_reset_gate =
-    zrg.getSharedDataTensor({batch_size, 1, 1, 2 * unit}, 0, false);
-  Tensor memory_cell =
-    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 2 * unit, false);
-
-  Tensor weight_hh_update_reset_gate;
-  Tensor weight_hh_memory_cell;
-  weight_hh_update_reset_gate.copy_with_stride(
-    weight_hh.getSharedDataTensor({unit, 2 * unit}, 0, false));
-  weight_hh_memory_cell.copy_with_stride(
-    weight_hh.getSharedDataTensor({unit, unit}, 2 * unit, false));
-
-  update_reset_gate.add_i_strided(
-    prev_hidden_state.dot(weight_hh_update_reset_gate));
-  if (!disable_bias) {
-    if (integrate_bias) {
-      const Tensor bias_h_update_reset_gate =
-        bias_h.getSharedDataTensor({2 * unit}, 0);
-      update_reset_gate.add_i(bias_h_update_reset_gate);
-    } else {
-      const Tensor bias_ih_update_reset_gate =
-        bias_ih.getSharedDataTensor({2 * unit}, 0);
-      update_reset_gate.add_i(bias_ih_update_reset_gate);
-      const Tensor bias_hh_update_reset_gate =
-        bias_hh.getSharedDataTensor({2 * unit}, 0);
-      update_reset_gate.add_i(bias_hh_update_reset_gate);
-    }
-  }
-
-  recurrent_acti_func.run_fn(update_reset_gate, update_reset_gate);
-
-  Tensor update_gate =
-    update_reset_gate.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
-  Tensor reset_gate = update_reset_gate.getSharedDataTensor(
-    {batch_size, 1, 1, unit}, unit, false);
-
-  Tensor temp;
-  if (reset_after) {
-    prev_hidden_state.dot(weight_hh_memory_cell, temp);
-    if (!disable_bias && !integrate_bias) {
-      const Tensor bias_hh_memory_cell =
-        bias_hh.getSharedDataTensor({unit}, 2 * unit);
-      temp.add_i(bias_hh_memory_cell);
-    }
-    temp.multiply_i_strided(reset_gate);
-    memory_cell.add_i_strided(temp);
-  } else {
-    reset_gate.multiply_strided(prev_hidden_state, temp);
-    temp.dot(weight_hh_memory_cell, memory_cell, false, false, 1.0f);
-    if (!disable_bias && !integrate_bias) {
-      const Tensor bias_hh_memory_cell =
-        bias_hh.getSharedDataTensor({unit}, 2 * unit);
-      memory_cell.add_i(bias_hh_memory_cell);
-    }
-  }
-  if (!disable_bias) {
-    if (integrate_bias) {
-      const Tensor bias_h_memory_cell =
-        bias_h.getSharedDataTensor({unit}, 2 * unit);
-      memory_cell.add_i(bias_h_memory_cell);
-    } else {
-      const Tensor bias_ih_memory_cell =
-        bias_ih.getSharedDataTensor({unit}, 2 * unit);
-      memory_cell.add_i(bias_ih_memory_cell);
-    }
-  }
-
-  acti_func.run_fn(memory_cell, memory_cell);
-
-  update_gate.multiply_strided(prev_hidden_state, hidden_state);
-  temp = update_gate.multiply(-1.0).add(1.0);
-  hidden_state.add_i(memory_cell.multiply_strided(temp));
+  grucell_forwarding(unit, batch_size, disable_bias, integrate_bias,
+                     reset_after, acti_func, recurrent_acti_func, input,
+                     prev_hidden_state, hidden_state, weight_ih, weight_hh,
+                     bias_h, bias_ih, bias_hh, zrg);
 
   if (dropout_rate > epsilon && training) {
     Tensor mask = context.getTensor(wt_idx[GRUCellParams::dropout_mask]);
@@ -314,7 +436,8 @@ void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
 }
 
 void GRUCellLayer::calcDerivative(RunLayerContext &context) {
-  Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
+  Tensor &outgoing_derivative =
+    context.getOutgoingDerivative(INOUT_INDEX::INPUT);
   const Tensor &weight_ih = context.getWeight(wt_idx[GRUCellParams::weight_ih]);
   const Tensor &d_zrg = context.getTensorGrad(wt_idx[GRUCellParams::zrg]);
 
@@ -330,11 +453,15 @@ void GRUCellLayer::calcGradient(RunLayerContext &context) {
     std::get<props::IntegrateBias>(grucell_props).get();
   const bool reset_after = std::get<props::ResetAfter>(grucell_props).get();
   const float dropout_rate = std::get<props::DropOutRate>(grucell_props).get();
-  const unsigned int max_timestep =
-    std::get<props::MaxTimestep>(grucell_props).get();
-  const unsigned int timestep = std::get<props::Timestep>(grucell_props).get();
 
-  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
+  const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
+  const Tensor &prev_hidden_state =
+    context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
+  Tensor &d_prev_hidden_state =
+    context.getOutgoingDerivative(INOUT_INDEX::INPUT_HIDDEN_STATE);
+  const Tensor &incoming_derivative =
+    context.getIncomingDerivative(INOUT_INDEX::OUTPUT);
+
   const unsigned int batch_size = input.getDim().batch();
 
   Tensor &d_weight_ih = context.getWeightGrad(wt_idx[GRUCellParams::weight_ih]);
@@ -355,171 +482,48 @@ void GRUCellLayer::calcGradient(RunLayerContext &context) {
                         ? context.getWeightGrad(wt_idx[GRUCellParams::bias_hh])
                         : empty;
 
-  Tensor d_weight_hh_update_reset_gate =
-    d_weight_hh.getSharedDataTensor({unit, 2 * unit}, 0, false);
-  Tensor d_weight_hh_memory_cell =
-    d_weight_hh.getSharedDataTensor({unit, unit}, 2 * unit, false);
-  Tensor &hidden_states =
-    context.getTensor(wt_idx[GRUCellParams::hidden_state]);
-  hidden_states.reshape({max_timestep, 1, batch_size, unit});
-  Tensor &d_hidden_states =
-    context.getTensorGrad(wt_idx[GRUCellParams::hidden_state]);
-  const Tensor &incoming_derivative =
-    context.getIncomingDerivative(SINGLE_INOUT_IDX);
   const Tensor &zrg = context.getTensor(wt_idx[GRUCellParams::zrg]);
   Tensor &d_zrg = context.getTensorGrad(wt_idx[GRUCellParams::zrg]);
 
-  d_hidden_states.reshape({max_timestep, 1, batch_size, unit});
-  Tensor d_hidden_state = d_hidden_states.getBatchSlice(timestep, 1);
-  d_hidden_state.reshape({batch_size, 1, 1, unit});
-  if (timestep + 1 == max_timestep) {
+  if (context.isGradientFirstAccess(wt_idx[GRUCellParams::weight_ih])) {
     d_weight_ih.setZero();
+  }
+  if (context.isGradientFirstAccess(wt_idx[GRUCellParams::weight_hh])) {
     d_weight_hh.setZero();
-    if (!disable_bias) {
-      if (integrate_bias) {
+  }
+  if (!disable_bias) {
+    if (integrate_bias) {
+      if (context.isGradientFirstAccess(wt_idx[GRUCellParams::bias_h])) {
         d_bias_h.setZero();
-      } else {
+      }
+    } else {
+      if (context.isGradientFirstAccess(wt_idx[GRUCellParams::bias_ih])) {
         d_bias_ih.setZero();
+      }
+      if (context.isGradientFirstAccess(wt_idx[GRUCellParams::bias_hh])) {
         d_bias_hh.setZero();
       }
     }
-    d_hidden_state.setZero();
   }
 
-  d_hidden_state.add_i(incoming_derivative);
-
-  Tensor prev_hidden_state;
-  Tensor d_prev_hidden_state;
-  if (timestep) {
-    prev_hidden_state = hidden_states.getBatchSlice(timestep - 1, 1);
-    d_prev_hidden_state = d_hidden_states.getBatchSlice(timestep - 1, 1);
-  } else {
-    prev_hidden_state = Tensor(batch_size, unit);
-    prev_hidden_state.setZero();
-    d_prev_hidden_state = Tensor(batch_size, unit);
-    d_prev_hidden_state.setZero();
-  }
-  prev_hidden_state.reshape({batch_size, 1, 1, unit});
-  d_prev_hidden_state.reshape({batch_size, 1, 1, unit});
+  Tensor d_hidden_state(batch_size, 1, 1, unit);
+  d_hidden_state.copyData(incoming_derivative);
 
   if (dropout_rate > epsilon) {
-    d_hidden_states.multiply_i(
+    d_hidden_state.multiply_i(
       context.getTensor(wt_idx[GRUCellParams::dropout_mask]));
   }
 
-  Tensor update_gate =
-    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
-  Tensor reset_gate =
-    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
-  Tensor memory_cell =
-    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 2 * unit, false);
-
-  Tensor d_update_gate =
-    d_zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
-  Tensor d_reset_gate =
-    d_zrg.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
-  Tensor d_memory_cell =
-    d_zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 2 * unit, false);
-
-  d_hidden_state.multiply_strided(
-    update_gate, d_prev_hidden_state); // d_prev_hidden_state = d1
-  d_hidden_state.multiply_strided(prev_hidden_state,
-                                  d_update_gate); // d_update_gate = d2
-  d_update_gate.add_i_strided(d_hidden_state.multiply_strided(memory_cell),
-                              -1.0f); // d_update_gate = d5
-  update_gate.multiply(-1.0, d_memory_cell);
-  d_memory_cell.add_i(1.0);
-  d_memory_cell.multiply_i_strided(d_hidden_state); // d_memory_cell = d6
-
-  recurrent_acti_func.run_prime_fn(update_gate, d_update_gate,
-                                   d_update_gate); // d_update_gate = d7
-  acti_func.run_prime_fn(memory_cell, d_memory_cell,
-                         d_memory_cell); // d_memory_cell = d8
-
-  Tensor d_update_reset_gate = d_zrg.getSharedDataTensor(
-    {batch_size, 1, 1, 2 * unit}, 0, false); // d_update_gate+d_reset_gate
-
-  Tensor weight_hh_memory_cell;
-  weight_hh_memory_cell.copy_with_stride(
-    weight_hh.getSharedDataTensor({unit, unit}, 2 * unit, false));
-  Tensor weight_hh_update_reset_gate;
-  weight_hh_update_reset_gate.copy_with_stride(
-    weight_hh.getSharedDataTensor({unit, 2 * unit}, 0, false));
-
-  Tensor temp = Tensor(batch_size, 1, 1, unit);
-  Tensor d_memory_cell_contiguous;
-  d_memory_cell_contiguous.copy_with_stride(d_memory_cell);
-
-  if (reset_after) {
-    prev_hidden_state.dot(weight_hh_memory_cell, temp);
-    if (!disable_bias && !integrate_bias) {
-      const Tensor bias_hh_memory_cell =
-        bias_hh.getSharedDataTensor({unit}, 2 * unit);
-      temp.add_i(bias_hh_memory_cell);
-    }
-    d_memory_cell_contiguous.multiply_strided(
-      temp, d_reset_gate); // d_reset_gate = d15
-
-    // reset temp: d_memory_cell_contiguous * reset_gate for
-    // d_bias_hh_memory_cell, d_prev_hidden_state and d_weight_hh_memory_cell
-    d_memory_cell_contiguous.multiply_strided(reset_gate, temp);
-    if (!disable_bias && !integrate_bias) {
-      Tensor d_bias_hh_memory_cell =
-        d_bias_hh.getSharedDataTensor({unit}, 2 * unit);
-      temp.sum(0, d_bias_hh_memory_cell, 1.0, 1.0);
-    }
-    temp.dot(weight_hh_memory_cell, d_prev_hidden_state, false, true,
-             1.0); // d_prev_hidden_state = d1 + d14
-    d_weight_hh_memory_cell.add_i_strided(
-      prev_hidden_state.dot(temp, true, false));
-  } else {
-    if (!disable_bias && !integrate_bias) {
-      Tensor d_bias_hh_memory_cell =
-        d_bias_hh.getSharedDataTensor({unit}, 2 * unit);
-      d_memory_cell.sum(0, d_bias_hh_memory_cell, 1.0, 1.0);
-    }
-
-    d_memory_cell_contiguous.dot(weight_hh_memory_cell, temp, false, true);
-    temp.multiply_strided(prev_hidden_state, d_reset_gate);
-    temp.multiply_strided(reset_gate, d_prev_hidden_state, 1.0f);
-
-    // reset temp: reset_gate * prev_hidden_state for and
-    // d_weight_hh_memory_cell
-    reset_gate.multiply_strided(prev_hidden_state, temp);
-    temp.dot(d_memory_cell_contiguous, d_weight_hh_memory_cell, true, false,
-             1.0f);
-  }
-
-  recurrent_acti_func.run_prime_fn(reset_gate, d_reset_gate,
-                                   d_reset_gate); // d_reset_gate = d16
-
-  if (!disable_bias) {
-    if (integrate_bias) {
-      d_zrg.sum(0, d_bias_h, 1.0, 1.0);
-    } else {
-      d_zrg.sum(0, d_bias_ih, 1.0, 1.0);
-      Tensor d_bias_hh_update_reset_gate =
-        d_bias_hh.getSharedDataTensor({2 * unit}, 0);
-      d_bias_hh_update_reset_gate.add_i(
-        d_zrg.sum(0).getSharedDataTensor({2 * unit}, 0));
-    }
-  }
-
-  Tensor d_update_reset_gate_contiguous;
-  d_update_reset_gate_contiguous.copy_with_stride(d_update_reset_gate);
-  d_weight_hh_update_reset_gate.add_i_strided(
-    prev_hidden_state.dot(d_update_reset_gate_contiguous, true, false));
-  input.dot(d_zrg, d_weight_ih, true, false, 1.0f);
-  d_update_reset_gate_contiguous.dot(
-    weight_hh_update_reset_gate, d_prev_hidden_state, false, true,
-    1.0); // d_prev_hidden_state = d1 + d14 + d12 + d17
+  grucell_calcGradient(unit, batch_size, disable_bias, integrate_bias,
+                       reset_after, acti_func, recurrent_acti_func, input,
+                       prev_hidden_state, d_prev_hidden_state, d_hidden_state,
+                       d_weight_ih, weight_hh, d_weight_hh, d_bias_h, d_bias_ih,
+                       bias_hh, d_bias_hh, zrg, d_zrg);
 }
 
 void GRUCellLayer::setBatch(RunLayerContext &context, unsigned int batch) {
   const float dropout_rate = std::get<props::DropOutRate>(grucell_props);
-  unsigned int &max_timestep = std::get<props::MaxTimestep>(grucell_props);
-  context.updateTensor(wt_idx[GRUCellParams::hidden_state],
-                       max_timestep * batch);
+
   context.updateTensor(wt_idx[GRUCellParams::zrg], batch);
   if (dropout_rate > epsilon) {
     context.updateTensor(wt_idx[GRUCellParams::dropout_mask], batch);
index d02281c..db6f11e 100644 (file)
@@ -99,25 +99,27 @@ public:
 
 private:
   static constexpr unsigned int NUM_GATE = 3;
+  enum INOUT_INDEX {
+    INPUT = 0,
+    INPUT_HIDDEN_STATE = 1,
+    OUTPUT = 0,
+  };
 
   /**
    * Unit: number of output neurons
-   * HiddenStateActivation: activation type for hidden state. default is tanh
-   * RecurrentActivation: activation type for recurrent. default is sigmoid
-   * DropOutRate: dropout rate
    * IntegrateBias: integrate bias_ih, bias_hh to bias_h
    * ResetAfter: Whether apply reset gate before/after the matrix
    * multiplication. Apply reset gate after the mulplication if true.
-   * MaxTimeStep: Maximum timestep of gru
-   * TimeStep: timestep for which gru should operate
+   * HiddenStateActivation: activation type for hidden state. default is tanh
+   * RecurrentActivation: activation type for recurrent. default is sigmoid
+   * DropOutRate: dropout rate
    *
    * */
-  std::tuple<props::Unit, props::HiddenStateActivation,
-             props::RecurrentActivation, props::DropOutRate,
-             props::IntegrateBias, props::ResetAfter, props::MaxTimestep,
-             props::Timestep>
+  std::tuple<props::Unit, props::IntegrateBias, props::ResetAfter,
+             props::HiddenStateActivation, props::RecurrentActivation,
+             props::DropOutRate>
     grucell_props;
-  std::array<unsigned int, 9> wt_idx; /**< indices of the weights */
+  std::array<unsigned int, 7> wt_idx; /**< indices of the weights */
 
   /**
    * @brief     activation function for h_t : default is sigmoid
index f2263eb..73f8c89 100644 (file)
Binary files a/packaging/unittest_layers_v2.tar.gz and b/packaging/unittest_layers_v2.tar.gz differ
index 28b6020..f752d8b 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
index 3ff2976..d906ce0 100644 (file)
@@ -128,7 +128,7 @@ if __name__ == "__main__":
 
     gru = K.layers.GRU(units=5, activation="tanh", 
                          recurrent_activation="sigmoid",
-                         bias_initializer='GlorotUniform',
+                         bias_initializer='glorot_uniform',
                          return_sequences=False,
                          return_state=False,
                          reset_after=False)
@@ -137,7 +137,7 @@ if __name__ == "__main__":
 
     gru = K.layers.GRU(units=5, activation="tanh", 
                          recurrent_activation="sigmoid",
-                         bias_initializer='GlorotUniform',
+                         bias_initializer='glorot_uniform',
                          return_sequences=True,
                          return_state=False,
                          reset_after=False)
@@ -146,7 +146,7 @@ if __name__ == "__main__":
 
     gru = K.layers.GRU(units=5, activation="sigmoid", 
                          recurrent_activation="tanh",
-                         bias_initializer='GlorotUniform',
+                         bias_initializer='glorot_uniform',
                          return_sequences=True,
                          return_state=False,
                          reset_after=False,)
@@ -155,7 +155,7 @@ if __name__ == "__main__":
     # check reset_after
     gru = K.layers.GRU(units=5, activation="tanh", 
                          recurrent_activation="sigmoid",
-                         bias_initializer='GlorotUniform',
+                         bias_initializer='glorot_uniform',
                          return_sequences=False,
                          return_state=False,
                          reset_after=True,)
@@ -164,7 +164,7 @@ if __name__ == "__main__":
 
     gru = K.layers.GRU(units=5, activation="tanh", 
                          recurrent_activation="sigmoid",
-                         bias_initializer='GlorotUniform',
+                         bias_initializer='glorot_uniform',
                          return_sequences=True,
                          return_state=False,
                          reset_after=True)
@@ -173,12 +173,32 @@ if __name__ == "__main__":
 
     gru = K.layers.GRU(units=5, activation="sigmoid", 
                          recurrent_activation="tanh",
-                         bias_initializer='GlorotUniform',
+                         bias_initializer='glorot_uniform',
                          return_sequences=True,
                          return_state=False,
                          reset_after=True)
     record_single(gru, (3, 4, 7), "gru_reset_after_multi_step_seq_act", input_type='float')
 
+    unit, batch_size, unroll_for, feature_size = [5, 3, 1, 7]
+    grucell = K.layers.GRUCell(units=unit,
+                         recurrent_activation='sigmoid',
+                         bias_initializer='glorot_uniform')
+    record_single(grucell, [(batch_size, feature_size), (batch_size, unit)], "grucell_single_step", input_type='float')
+
+    unit, batch_size, unroll_for, feature_size = [5, 3, 1, 7]
+    grucell = K.layers.GRUCell(units=unit,
+                         recurrent_activation='sigmoid',
+                         bias_initializer='glorot_uniform',
+                         reset_after=True)
+    record_single(grucell, [(batch_size, feature_size), (batch_size, unit)], "grucell_reset_after_single_step", input_type='float')
+
+    unit, batch_size, unroll_for, feature_size = [5, 3, 1, 7]
+    grucell = K.layers.GRUCell(units=unit,
+                         activation="sigmoid",
+                         recurrent_activation="tanh",
+                         bias_initializer='glorot_uniform')
+    record_single(grucell, [(batch_size, feature_size), (batch_size, unit)], "grucell_single_step_act", input_type='float')
+
     dropout = K.layers.Dropout(rate=0.2)
     record_single(dropout, (2, 3, 2, 3), "dropout_20_training", {"training": True})
     record_single(dropout, (2, 3, 2, 3), "dropout_20_inference", {"training": False})
index c20f3d4..47e1fd8 100644 (file)
@@ -145,21 +145,21 @@ class ZoneoutLSTMStacked(torch.nn.Module):
         return ret, loss
 
 class GRUCellStacked(torch.nn.Module):
-    def __init__(self, unroll_for=2, num_gru=1):
+    def __init__(self, unroll_for=2, num_grucell=1):
         super().__init__()
         self.input_size = self.hidden_size = 2
         self.grus = torch.nn.ModuleList(
             [
                 torch.nn.GRUCell(self.input_size, self.hidden_size, bias=True)
-                for _ in range(num_gru)
+                for _ in range(num_grucell)
             ]
         )
         self.unroll_for = unroll_for
         self.loss = torch.nn.MSELoss()
 
     def forward(self, inputs, labels):
-        hs = [torch.zeros_like(inputs[0]) for _ in self.grus]
         out = inputs[0]
+        hs = inputs[1:]
         ret = []
         for _ in range(self.unroll_for):
             for i, (gru, h) in enumerate(zip(self.grus, hs)):
@@ -409,19 +409,21 @@ if __name__ == "__main__":
         name="zoneout_lstm_stacked_100_100",
     )
 
+    unroll_for, num_grucell, batch_size, unit, feature_size, iteration, = [2, 1, 3, 2, 2, 2]
     record_v2(
-        GRUCellStacked(unroll_for=2, num_gru=1),
-        iteration=2,
-        input_dims=[(3, 2)],
-        label_dims=[(3, 2, 2)],
+        GRUCellStacked(unroll_for=unroll_for, num_grucell=num_grucell),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(num_grucell)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="grucell_single",
     )
 
+    unroll_for, num_grucell, batch_size, unit, feature_size, iteration, = [2, 2, 3, 2, 2, 2]
     record_v2(
-        GRUCellStacked(unroll_for=2, num_gru=2),
-        iteration=2,
-        input_dims=[(3, 2)],
-        label_dims=[(3, 2, 2)],
+        GRUCellStacked(unroll_for=unroll_for, num_grucell=num_grucell),
+        iteration=iteration,
+        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(num_grucell)],
+        label_dims=[(batch_size, unroll_for, unit)],
         name="grucell_stacked",
     )
 
index c859f3d..ac45a8c 100644 (file)
@@ -244,6 +244,33 @@ class GRUTransLayer(IdentityTransLayer):
     def to_nntr_trainable_weights(self, tensorOrList):
         return self.to_nntr_weights(tensorOrList)
 
+class GRUCellTransLayer(IdentityTransLayer):
+    def build(self, input_shape):
+        if not self.built:
+            self.tf_layer.build(input_shape[0])
+            super().build(input_shape[0])
+
+    ##
+    # @brief call function
+    # @param inputs input with nntrainer layout
+    def call(self, inputs):
+        input = inputs[0]
+        states = inputs[1:]
+        output, states = self.tf_layer.call(input, states)
+        # print(output)
+        return output
+
+    def to_nntr_weights(self, tensorOrList):
+        bias = tensorOrList[2]
+        if bias.shape.rank == 2:
+            bias_ih, bias_hh = bias[0], bias[1]
+            return [tensorOrList[0], tensorOrList[1], bias_ih, bias_hh]
+        else:
+            return tensorOrList
+
+    def to_nntr_trainable_weights(self, tensorOrList):
+        return self.to_nntr_weights(tensorOrList)
+
 ##
 # @brief A factory function to attach translayer to existing layer
 # if nothing should be attached, it does not attach the layer
@@ -262,4 +289,7 @@ def attach_trans_layer(layer):
     if isinstance(layer, K.layers.GRU):
         return GRUTransLayer(layer)
 
+    if isinstance(layer, K.layers.GRUCell):
+        return GRUCellTransLayer(layer)
+
     return layer
index b331d54..c9c0e3f 100644 (file)
 #include <grucell.h>
 #include <layers_common_tests.h>
 
-auto semantic_grucell =
-  LayerSemanticsParamType(nntrainer::createLayer<nntrainer::GRUCellLayer>,
-                          nntrainer::GRUCellLayer::type,
-                          {"unit=1", "max_timestep=1", "timestep=0",
-                           "integrate_bias=false", "reset_after=true"},
-                          0, false, 1);
+auto semantic_grucell = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::GRUCellLayer>,
+  nntrainer::GRUCellLayer::type,
+  {"unit=1", "integrate_bias=false", "reset_after=true"}, 0, false, 2);
 
 INSTANTIATE_TEST_CASE_P(GRUCell, LayerSemantics,
                         ::testing::Values(semantic_grucell));
 
-auto grucell_single_step =
-  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::GRUCellLayer>,
-                           {"unit=5", "max_timestep=1", "timestep=0",
-                            "integrate_bias=true", "reset_after=false"},
-                           "3:1:1:7", "gru_single_step.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+auto grucell_single_step = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::GRUCellLayer>,
+  {"unit=5", "integrate_bias=true", "reset_after=false"}, "3:1:1:7,3:1:1:5",
+  "grucell_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+
+auto grucell_reset_after_single_step = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::GRUCellLayer>,
+  {"unit=5", "integrate_bias=false", "reset_after=true"}, "3:1:1:7,3:1:1:5",
+  "grucell_reset_after_single_step.nnlayergolden",
+  LayerGoldenTestParamOptions::DEFAULT);
+
+auto grucell_single_step_act = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::GRUCellLayer>,
+  {"unit=5", "integrate_bias=true", "reset_after=false",
+   "hidden_state_activation=sigmoid", "recurrent_activation=tanh"},
+  "3:1:1:7,3:1:1:5", "grucell_single_step_act.nnlayergolden",
+  LayerGoldenTestParamOptions::DEFAULT);
 
 INSTANTIATE_TEST_CASE_P(GRUCell, LayerGoldenTest,
-                        ::testing::Values(grucell_single_step));
+                        ::testing::Values(grucell_single_step,
+                                          grucell_reset_after_single_step,
+                                          grucell_single_step_act));
index d7ac3b2..7859851 100644 (file)
@@ -459,6 +459,7 @@ static std::unique_ptr<NeuralNetwork> makeSingleGRUCell() {
 
   auto outer_graph = makeGraph({
     {"input", {"name=input", "input_shape=1:1:2"}},
+    {"input", {"name=input_hidden_state", "input_shape=1:1:2"}},
     /// here grucell is being inserted
     {"mse", {"name=loss", "input_layers=grucell_scope/a1"}},
   });
@@ -467,19 +468,24 @@ static std::unique_ptr<NeuralNetwork> makeSingleGRUCell() {
   }
 
   auto grucell = makeGraph({
+    {"input", {"name=dummy_0", "input_shape=1"}},
+    {"input", {"name=dummy_1", "input_shape=1"}},
     {"grucell",
-     {"name=a1", "unit=2", "integrate_bias=false", "reset_after=true"}},
+     {"name=a1", "unit=2", "integrate_bias=false", "reset_after=true",
+      "input_layers=dummy_0, dummy_1"}},
   });
 
-  nn->addWithReferenceLayers(grucell, "grucell_scope", {"input"}, {"a1"},
-                             {"a1"}, ml::train::ReferenceLayersType::RECURRENT,
-                             {
-                               "unroll_for=2",
-                               "as_sequence=a1",
-                               "recurrent_input=a1",
-                               "recurrent_output=a1",
-                             });
+  nn->addWithReferenceLayers(
+    grucell, "grucell_scope", {"input", "input_hidden_state"},
+    {"a1(0)", "a1(1)"}, {"a1"}, ml::train::ReferenceLayersType::RECURRENT,
+    {
+      "unroll_for=2",
+      "as_sequence=a1",
+      "recurrent_input=a1(0), a1(1)",
+      "recurrent_output=a1(0), a1(0)",
+    });
 
+  nn->setProperty({"input_layers=input, input_hidden_state"});
   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
   return nn;
 }
@@ -490,30 +496,41 @@ static std::unique_ptr<NeuralNetwork> makeStackedGRUCell() {
 
   auto outer_graph = makeGraph({
     {"input", {"name=input", "input_shape=1:1:2"}},
+    {"input", {"name=a1_input_hidden_state", "input_shape=1:1:2"}},
+    {"input", {"name=a2_input_hidden_state", "input_shape=1:1:2"}},
     /// here grucells are being inserted
-    {"mse", {"name=loss", "input_layers=grucell_scope/a2"}},
+    {"mse", {"name=loss", "input_layers=grucell_scope/a2(0)"}},
   });
   for (auto &node : outer_graph) {
     nn->addLayer(node);
   }
 
   auto grucell = makeGraph({
+    {"input", {"name=dummy_0", "input_shape=1"}},
+    {"input", {"name=dummy_1", "input_shape=1"}},
+    {"input", {"name=dummy_2", "input_shape=1"}},
     {"grucell",
-     {"name=a1", "unit=2", "integrate_bias=false", "reset_after=true"}},
+     {"name=a1", "unit=2", "integrate_bias=false", "reset_after=true",
+      "input_layers=dummy_0, dummy_1"}},
     {"grucell",
      {"name=a2", "unit=2", "integrate_bias=false", "reset_after=true",
-      "input_layers=a1"}},
+      "input_layers=a1(0), dummy_2"}},
   });
 
-  nn->addWithReferenceLayers(grucell, "grucell_scope", {"input"}, {"a1"},
-                             {"a2"}, ml::train::ReferenceLayersType::RECURRENT,
-                             {
-                               "unroll_for=2",
-                               "as_sequence=a2",
-                               "recurrent_input=a1",
-                               "recurrent_output=a2",
-                             });
+  nn->addWithReferenceLayers(
+    grucell, "grucell_scope",
+    {"input", "a1_input_hidden_state", "a2_input_hidden_state"},
+    {"a1(0)", "a1(1)", "a2(1)"}, {"a2"},
+    ml::train::ReferenceLayersType::RECURRENT,
+    {
+      "unroll_for=2",
+      "as_sequence=a2",
+      "recurrent_input=a1(0), a1(1), a2(1)",
+      "recurrent_output=a2(0), a1(0), a2(0)",
+    });
 
+  nn->setProperty(
+    {"input_layers=input, a1_input_hidden_state, a2_input_hidden_state"});
   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
   return nn;
 }
@@ -576,9 +593,8 @@ INSTANTIATE_TEST_CASE_P(
                  ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedRNNCell, "rnncell_stacked__1",
                  ModelTestOption::ALL_V2),
-    mkModelTc_V2(makeSingleGRUCell, "grucell_single__1",
-                 ModelTestOption::ALL_V2),
-    mkModelTc_V2(makeStackedGRUCell, "grucell_stacked__1",
+    mkModelTc_V2(makeSingleGRUCell, "grucell_single", ModelTestOption::ALL_V2),
+    mkModelTc_V2(makeStackedGRUCell, "grucell_stacked",
                  ModelTestOption::ALL_V2),
   }),
   [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info) {