[ Recurrent ] property for dynamic time sequence
authorjijoong.moon <jijoong.moon@samsung.com>
Wed, 15 Jun 2022 11:37:08 +0000 (20:37 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 16 Jun 2022 07:40:04 +0000 (16:40 +0900)
.This patch provides property for dynamic time sequence in recurrent
realizer. The sementic of this is "dynamic_time_seq = true/false"

.Add grucell + Fully Connected uint test case for reference of dynamic
time sequence

Related : #1933

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
nntrainer/compiler/recurrent_realizer.cpp
nntrainer/compiler/recurrent_realizer.h
test/input_gen/genModelsRecurrent_v2.py
test/unittest/models/unittest_models_recurrent.cpp

index 6095450..baf867d 100644 (file)
@@ -46,6 +46,22 @@ public:
 UnrollFor::UnrollFor(const unsigned &value) { set(value); }
 
 /**
+ * @brief dynamic time sequence property, use this to set and check if dynamic
+ * time sequence is enabled.
+ *
+ */
+class DynamicTimeSequence final : public nntrainer::Property<bool> {
+public:
+  /**
+   * @brief Construct a new DynamicTimeSequence object
+   *
+   */
+  DynamicTimeSequence(bool val = true) : nntrainer::Property<bool>(val) {}
+  static constexpr const char *key = "dynamic_time_seq";
+  using prop_tag = bool_prop_tag;
+};
+
+/**
  * @brief Property for recurrent inputs
  *
  */
@@ -105,7 +121,7 @@ RecurrentRealizer::RecurrentRealizer(const std::vector<std::string> &properties,
   recurrent_props(new PropTypes(
     std::vector<props::RecurrentInput>(), std::vector<props::RecurrentOutput>(),
     std::vector<props::AsSequence>(), props::UnrollFor(1),
-    std::vector<props::InputIsSequence>())) {
+    std::vector<props::InputIsSequence>(), props::DynamicTimeSequence(false))) {
   auto left = loadProperties(properties, *recurrent_props);
 
   std::transform(input_conns.begin(), input_conns.end(),
@@ -131,8 +147,8 @@ RecurrentRealizer::RecurrentRealizer(const std::vector<std::string> &properties,
     }
   }
 
-  auto &[inputs, outputs, as_sequence, unroll_for, input_is_seq] =
-    *recurrent_props;
+  auto &[inputs, outputs, as_sequence, unroll_for, input_is_seq,
+         dynamic_time_seq] = *recurrent_props;
 
   NNTR_THROW_IF(inputs.empty() || inputs.size() != outputs.size(),
                 std::invalid_argument)
index ac81f8d..e6e390c 100644 (file)
@@ -33,6 +33,7 @@ class InputIsSequence;
 class OutputLayer;
 class RecurrentInput;
 class RecurrentOutput;
+class DynamicTimeSequence;
 } // namespace props
 
 /**
@@ -87,10 +88,11 @@ public:
   GraphRepresentation realize(const GraphRepresentation &reference) override;
 
 private:
-  using PropTypes = std::tuple<std::vector<props::RecurrentInput>,
-                               std::vector<props::RecurrentOutput>,
-                               std::vector<props::AsSequence>, props::UnrollFor,
-                               std::vector<props::InputIsSequence>>;
+  using PropTypes =
+    std::tuple<std::vector<props::RecurrentInput>,
+               std::vector<props::RecurrentOutput>,
+               std::vector<props::AsSequence>, props::UnrollFor,
+               std::vector<props::InputIsSequence>, props::DynamicTimeSequence>;
 
   std::unordered_set<std::string> input_layers; /**< external input layers */
   std::unordered_set<std::string>
index e7fc61b..d02dd92 100644 (file)
@@ -13,7 +13,7 @@ from zoneout import Zoneout
 import torch
 
 class FCUnroll(torch.nn.Module):
-    def __init__(self, unroll_for=1, num_fc=1):
+    def __init__(self, unroll_for = 1, num_fc = 1):
         super().__init__()
         self.fcs = torch.nn.ModuleList([torch.nn.Linear(1, 1) for i in range(num_fc)])
         self.unroll_for = unroll_for
@@ -30,7 +30,7 @@ class FCUnroll(torch.nn.Module):
         return output, loss
 
 class RNNCellStacked(torch.nn.Module):
-    def __init__(self, unroll_for=2, num_rnncell=1, input_size=2, hidden_size=2):
+    def __init__(self, unroll_for = 2, num_rnncell = 1, input_size = 2, hidden_size = 2):
         super().__init__()
         self.rnncells = torch.nn.ModuleList(
             [
@@ -51,19 +51,19 @@ class RNNCellStacked(torch.nn.Module):
                 out = hs[i]
             ret.append(out)
 
-        ret = torch.stack(ret, dim=1)
+        ret = torch.stack(ret, dim = 1)
         loss = self.loss(ret, labels[0])
         return ret, loss
 
 class LSTMStacked(torch.nn.Module):
-    def __init__(self, num_lstm=1, bidirectional=False):
+    def __init__(self, num_lstm = 1, bidirectional = False):
         super().__init__()
         self.input_size = self.hidden_size = 2
         self.num_lstm = num_lstm
-        self.bidirectional=bidirectional
+        self.bidirectional = bidirectional
         self.lstms = torch.nn.ModuleList(
             [
-                torch.nn.LSTM(self.input_size if self.bidirectional == False or i == 0 else 2 * self.input_size, self.hidden_size, batch_first=True, bidirectional=bidirectional)
+                torch.nn.LSTM(self.input_size if self.bidirectional == False or i == 0 else 2 * self.input_size, self.hidden_size, batch_first = True, bidirectional = bidirectional)
                 # Intended comment
                 # torch.nn.LSTM(self.input_size if self.bidirectional == False or i == 0 else 2 * self.input_size, self.hidden_size, num_layers=num_lstm, batch_first=True, bidirectional=bidirectional)
                 for i in range(num_lstm)
@@ -85,7 +85,7 @@ class LSTMStacked(torch.nn.Module):
         return out, loss
 
 class LSTMCellStacked(torch.nn.Module):
-    def __init__(self, unroll_for=2, num_lstmcell=1):
+    def __init__(self, unroll_for = 2, num_lstmcell = 1):
         super().__init__()
         self.input_size = self.hidden_size = 2
         self.lstmcells = torch.nn.ModuleList(
@@ -110,12 +110,12 @@ class LSTMCellStacked(torch.nn.Module):
                 out = hs[i]
             ret.append(out)
 
-        ret = torch.stack(ret, dim=1)
+        ret = torch.stack(ret, dim = 1)
         loss = self.loss(ret, labels[0])
         return ret, loss
 
 class ZoneoutLSTMStacked(torch.nn.Module):
-    def __init__(self, batch_size=3, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=1, cell_state_zoneout_rate=1):
+    def __init__(self, batch_size = 3, unroll_for = 2, num_lstm = 1, hidden_state_zoneout_rate = 1, cell_state_zoneout_rate = 1):
         super().__init__()
         self.input_size = self.hidden_size = 2
         self.cell_state_zoneout_rate = cell_state_zoneout_rate
@@ -141,17 +141,17 @@ class ZoneoutLSTMStacked(torch.nn.Module):
                 out = hs[i]
             ret.append(out)
 
-        ret = torch.stack(ret, dim=1)
+        ret = torch.stack(ret, dim = 1)
         loss = self.loss(ret, labels[0])
         return ret, loss
 
 class GRUCellStacked(torch.nn.Module):
-    def __init__(self, unroll_for=2, num_grucell=1):
+    def __init__(self, unroll_for = 2, num_grucell = 1):
         super().__init__()
         self.input_size = self.hidden_size = 2
         self.grus = torch.nn.ModuleList(
             [
-                torch.nn.GRUCell(self.input_size, self.hidden_size, bias=True)
+                torch.nn.GRUCell(self.input_size, self.hidden_size, bias = True)
                 for _ in range(num_grucell)
             ]
         )
@@ -168,291 +168,322 @@ class GRUCellStacked(torch.nn.Module):
                 out = hs[i]
             ret.append(out)
 
-        ret = torch.stack(ret, dim=1)
+        ret = torch.stack(ret, dim = 1)
+        loss = self.loss(ret, labels[0])
+        return ret, loss
+
+class GRUCellFC(torch.nn.Module):
+    def __init__(self, unroll_for = 2, num_grucell = 1):
+        super().__init__()
+        self.input_size = self.hidden_size = 2
+        self.gru=torch.nn.GRUCell(self.input_size, self.hidden_size, bias = True)
+        self.fc = torch.nn.Linear(2,2)
+        self.unroll_for = unroll_for
+        self.loss = torch.nn.MSELoss()
+
+    def forward(self, inputs, labels):
+        out = inputs[0]
+        hs = inputs[1]
+        ret = []
+        for _ in range(self.unroll_for):
+            hs = self.gru(out, hs)
+            out = self.fc(hs)
+            ret.append(out)
+
+        ret = torch.stack(ret, dim = 1)
         loss = self.loss(ret, labels[0])
         return ret, loss
 
 if __name__ == "__main__":
     record_v2(
-        FCUnroll(unroll_for=5),
-        iteration=2,
-        input_dims=[(1,)],
-        label_dims=[(1,)],
-        name="fc_unroll_single",
+        FCUnroll(unroll_for = 5),
+        iteration = 2,
+        input_dims = [(1,)],
+        label_dims = [(1,)],
+        name = "fc_unroll_single",
     )
 
     record_v2(
-        FCUnroll(unroll_for=2, num_fc=2),
-        iteration=2,
-        input_dims=[(1,)],
-        label_dims=[(1,)],
-        name="fc_unroll_stacked",
+        FCUnroll(unroll_for = 2, num_fc = 2),
+        iteration = 2,
+        input_dims = [(1,)],
+        label_dims = [(1,)],
+        name = "fc_unroll_stacked",
     )
 
     record_v2(
-        FCUnroll(unroll_for=2, num_fc=2),
-        iteration=2,
-        input_dims=[(1,)],
-        label_dims=[(1,)],
-        name="fc_unroll_stacked_clipped",
-        clip=True
+        FCUnroll(unroll_for = 2, num_fc = 2),
+        iteration = 2,
+        input_dims = [(1,)],
+        label_dims = [(1,)],
+        name = "fc_unroll_stacked_clipped",
+        clip = True
     )
 
 
     unroll_for, num_rnncell, batch_size, unit, feature_size, iteration = [2, 1, 3, 2, 2, 2]
     record_v2(
-        RNNCellStacked(unroll_for=unroll_for, num_rnncell=num_rnncell, input_size=feature_size, hidden_size=unit),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(num_rnncell)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="rnncell_single",
+        RNNCellStacked(unroll_for = unroll_for, num_rnncell = num_rnncell, input_size = feature_size, hidden_size = unit),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(num_rnncell)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "rnncell_single",
     )
 
     unroll_for, num_rnncell, batch_size, unit, feature_size, iteration = [2, 2, 3, 2, 2, 2]
     record_v2(
-        RNNCellStacked(unroll_for=unroll_for, num_rnncell=num_rnncell, input_size=feature_size, hidden_size=unit),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(num_rnncell)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="rnncell_stacked",
+        RNNCellStacked(unroll_for = unroll_for, num_rnncell = num_rnncell, input_size = feature_size, hidden_size = unit),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(num_rnncell)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "rnncell_stacked",
     )
 
     unroll_for, num_lstm, batch_size, unit, feature_size, iteration, bidirectional = [2, 1, 3, 2, 2, 2, False]
     record_v2(
-        LSTMStacked(num_lstm=num_lstm, bidirectional=bidirectional),
-        iteration=iteration,
-        input_dims=[(batch_size, unroll_for, feature_size)],
-        # input_dims=[(batch_size, unroll_for, feature_size)] + [(1, batch_size, unit) for _ in range(2 * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="lstm_single",
+        LSTMStacked(num_lstm = num_lstm, bidirectional = bidirectional),
+        iteration = iteration,
+        input_dims = [(batch_size, unroll_for, feature_size)],
+        # input_dims = [(batch_size, unroll_for, feature_size)] + [(1, batch_size, unit) for _ in range(2 * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "lstm_single",
     )
 
     unroll_for, num_lstm, batch_size, unit, feature_size, iteration, bidirectional = [2, 2, 3, 2, 2, 2, False]
     record_v2(
-        LSTMStacked(num_lstm=num_lstm, bidirectional=bidirectional),
-        iteration=iteration,
-        input_dims=[(batch_size, unroll_for, feature_size)],
-        # input_dims=[(batch_size, unroll_for, feature_size)] + [(1, batch_size, unit) for _ in range(2 * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="lstm_stacked",
+        LSTMStacked(num_lstm = num_lstm, bidirectional = bidirectional),
+        iteration = iteration,
+        input_dims = [(batch_size, unroll_for, feature_size)],
+        # input_dims = [(batch_size, unroll_for, feature_size)] + [(1, batch_size, unit) for _ in range(2 * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "lstm_stacked",
     )
 
     unroll_for, num_lstm, batch_size, unit, feature_size, iteration, bidirectional = [2, 1, 3, 2, 2, 2, True]
     record_v2(
-        LSTMStacked(num_lstm=num_lstm, bidirectional=bidirectional),
-        iteration=iteration,
-        input_dims=[(batch_size, unroll_for, feature_size)],
-        # input_dims=[(batch_size, unroll_for, feature_size)] + [(2, batch_size, unit) for _ in range(2 * num_lstm)],
-        label_dims=[(batch_size, unroll_for, 2 * unit)],
-        name="bidirectional_lstm_single",
+        LSTMStacked(num_lstm = num_lstm, bidirectional = bidirectional),
+        iteration = iteration,
+        input_dims = [(batch_size, unroll_for, feature_size)],
+        # input_dims = [(batch_size, unroll_for, feature_size)] + [(2, batch_size, unit) for _ in range(2 * num_lstm)],
+        label_dims = [(batch_size, unroll_for, 2 * unit)],
+        name = "bidirectional_lstm_single",
     )
 
     unroll_for, num_lstm, batch_size, unit, feature_size, iteration, bidirectional = [2, 2, 3, 2, 2, 2, True]
     record_v2(
-        LSTMStacked(num_lstm=num_lstm, bidirectional=bidirectional),
-        iteration=iteration,
-        input_dims=[(batch_size, unroll_for, feature_size)],
-        # input_dims=[(batch_size, unroll_for, feature_size)] + [(2, batch_size, unit) for _ in range(2 * num_lstm)],
-        label_dims=[(batch_size, unroll_for, 2 * unit)],
-        name="bidirectional_lstm_stacked",
+        LSTMStacked(num_lstm = num_lstm, bidirectional = bidirectional),
+        iteration = iteration,
+        input_dims = [(batch_size, unroll_for, feature_size)],
+        # input_dims = [(batch_size, unroll_for, feature_size)] + [(2, batch_size, unit) for _ in range(2 * num_lstm)],
+        label_dims = [(batch_size, unroll_for, 2 * unit)],
+        name = "bidirectional_lstm_stacked",
     )
 
     unroll_for, num_lstmcell, state_num, batch_size, unit, feature_size, iteration = [2, 1, 2, 3, 2, 2, 2]
     record_v2(
-        LSTMCellStacked(unroll_for=unroll_for, num_lstmcell=num_lstmcell),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstmcell)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="lstmcell_single",
+        LSTMCellStacked(unroll_for = unroll_for, num_lstmcell = num_lstmcell),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstmcell)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "lstmcell_single",
     )
 
     unroll_for, num_lstmcell, state_num, batch_size, unit, feature_size, iteration = [2, 2, 2, 3, 2, 2, 2]
     record_v2(
-        LSTMCellStacked(unroll_for=unroll_for, num_lstmcell=num_lstmcell),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstmcell)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="lstmcell_stacked",
+        LSTMCellStacked(unroll_for = unroll_for, num_lstmcell = num_lstmcell),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstmcell)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "lstmcell_stacked",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 0.0, 0.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_single_000_000",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_single_000_000",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 0.0, 0.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_stacked_000_000",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_stacked_000_000",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 0.5, 0.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_single_050_000",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_single_050_000",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 0.5, 0.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_stacked_050_000",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_stacked_050_000",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 1.0, 0.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_single_100_000",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_single_100_000",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 1.0, 0.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_stacked_100_000",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_stacked_100_000",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 0.0, 0.5]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_single_000_050",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_single_000_050",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 0.0, 0.5]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_stacked_000_050",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_stacked_000_050",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 0.5, 0.5]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_single_050_050",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_single_050_050",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 0.5, 0.5]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_stacked_050_050",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_stacked_050_050",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 1.0, 0.5]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_single_100_050",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_single_100_050",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 1.0, 0.5]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_stacked_100_050",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_stacked_100_050",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 0.0, 1.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_single_000_100",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_single_000_100",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 0.0, 1.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_stacked_000_100",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_stacked_000_100",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 0.5, 1.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_single_050_100",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_single_050_100",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 0.5, 1.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_stacked_050_100",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_stacked_050_100",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 1.0, 1.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_single_100_100",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_single_100_100",
     )
 
     unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 2, 2, 1, 2, 2, 2, 1.0, 1.0]
     record_v2(
-        ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="zoneout_lstm_stacked_100_100",
+        ZoneoutLSTMStacked(batch_size = batch_size, unroll_for = unroll_for, num_lstm = num_lstm, hidden_state_zoneout_rate = hidden_state_zoneout_rate, cell_state_zoneout_rate = cell_state_zoneout_rate),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstm)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "zoneout_lstm_stacked_100_100",
     )
 
     unroll_for, num_grucell, batch_size, unit, feature_size, iteration, = [2, 1, 3, 2, 2, 2]
     record_v2(
-        GRUCellStacked(unroll_for=unroll_for, num_grucell=num_grucell),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(num_grucell)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="grucell_single",
+        GRUCellStacked(unroll_for = unroll_for, num_grucell = num_grucell),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(num_grucell)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "grucell_single",
     )
 
     unroll_for, num_grucell, batch_size, unit, feature_size, iteration, = [2, 2, 3, 2, 2, 2]
     record_v2(
-        GRUCellStacked(unroll_for=unroll_for, num_grucell=num_grucell),
-        iteration=iteration,
-        input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(num_grucell)],
-        label_dims=[(batch_size, unroll_for, unit)],
-        name="grucell_stacked",
+        GRUCellStacked(unroll_for = unroll_for, num_grucell = num_grucell),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(num_grucell)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "grucell_stacked",
+    )
+
+    unroll_for, num_grucell, batch_size, unit, feature_size, iteration, = [2, 1, 3, 2, 2, 2]
+    record_v2(
+        GRUCellFC(unroll_for = unroll_for, num_grucell = num_grucell),
+        iteration = iteration,
+        input_dims = [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(num_grucell)],
+        label_dims = [(batch_size, unroll_for, unit)],
+        name = "grucell_fc",
     )
 
     # inspect_file("lstm_single.nnmodelgolden")
index 3ca4bc9..0f4cf5b 100644 (file)
@@ -573,6 +573,40 @@ static std::unique_ptr<NeuralNetwork> makeStackedGRUCell() {
   return nn;
 }
 
+static std::unique_ptr<NeuralNetwork> makeStackedGRUCellFC() {
+  auto nn = std::make_unique<NeuralNetwork>();
+  nn->setProperty({"batch_size=3"});
+
+  auto outer_graph = makeGraph({
+    {"input", {"name=input", "input_shape=1:1:2"}},
+    {"input", {"name=a1_input_hidden_state", "input_shape=1:1:2"}},
+    /// here grucells are being inserted
+    {"mse", {"name=loss", "input_layers=grucell_scope/fc(0)"}},
+  });
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  auto grucell = makeGraph({
+    {"input", {"name=dummy_0", "input_shape=1"}},
+    {"input", {"name=dummy_1", "input_shape=1"}},
+    {"grucell",
+     {"name=a1", "unit=2", "integrate_bias=false", "reset_after=true",
+      "input_layers=dummy_0, dummy_1"}},
+    {"fully_connected", {"name=fc", "unit=2", "input_layers=a1(0)"}},
+  });
+
+  nn->addWithReferenceLayers(
+    grucell, "grucell_scope", {"input", "a1_input_hidden_state"},
+    {"a1(0)", "a1(1)"}, {"fc"}, ml::train::ReferenceLayersType::RECURRENT,
+    {"unroll_for=2", "as_sequence=fc", "recurrent_input=a1(0), a1(1)",
+     "recurrent_output=fc(0), a1(0)", "dynamic_time_seq=true"});
+
+  nn->setProperty({"input_layers=input, a1_input_hidden_state"});
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+  return nn;
+}
+
 INSTANTIATE_TEST_CASE_P(
   recurrentModels, nntrainerModelTest,
   ::testing::ValuesIn({
@@ -637,6 +671,7 @@ INSTANTIATE_TEST_CASE_P(
     mkModelTc_V2(makeSingleGRUCell, "grucell_single", ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedGRUCell, "grucell_stacked",
                  ModelTestOption::ALL_V2),
+    mkModelTc_V2(makeStackedGRUCellFC, "grucell_fc", ModelTestOption::ALL_V2),
   }),
   [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info) {
     return std::get<1>(info.param);