[Test] Add recurrent value compare tests
authorJihoon Lee <jhoon.it.lee@samsung.com>
Tue, 19 Oct 2021 11:19:01 +0000 (20:19 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 20 Oct 2021 12:19:55 +0000 (21:19 +0900)
This patch adds golden value-compare tests for recurrent models.

Four cases are presented; a brief generation sketch follows the list.

1. single fc recurrent
2. stacked fc recurrent
3. single lstm recurrent
4. stacked lstm recurrent
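
For reference, a minimal sketch of how one of these golden files is produced
with the record_v2 helper. It mirrors the call added in
test/input_gen/genModelsRecurrent_v2.py below; the golden file name is an
assumption based on the .gitignore entry and the inspect_file comment.

    # Sketch: generate golden data for the single fc recurrent case.
    # Assumes test/input_gen is on the Python path.
    from recorder_v2 import record_v2
    from genModelsRecurrent_v2 import FCUnroll

    record_v2(
        FCUnroll(unroll_for=5),   # one fc cell, unrolled 5 times
        iteration=2,              # record two forward/backward iterations
        input_dims=[(1,)],
        label_dims=[(1,)],
        name="fc_unroll_single",  # expected to write fc_unroll_single.nnmodelgolden
    )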

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
.gitignore
nntrainer/layers/lstm.cpp
packaging/unittest_models_v2.tar.gz [new file with mode: 0644]
test/input_gen/genModelsRecurrent_v2.py [new file with mode: 0644]
test/input_gen/recorder_v2.py
test/unittest/meson.build
test/unittest/models/unittest_models_recurrent.cpp

index c4d8534..5678f6f 100644 (file)
@@ -43,6 +43,7 @@ Applications/**/*.bin
 *.a
 *.o.d
 *.nnlayergolden
+*.nnmodelgolden
 
 # log files
 *.log
index 9248c3a..1d964f3 100644 (file)
@@ -395,6 +395,7 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
       Tensor rdata =
         incoming_deriv.getSharedDataTensor({d.width()}, b * d.width());
       /// @note this is not copying from start ~ end but only start time step
+      // This copies for the self-rolling case as well as the last unrolled
+      // recurrent step.
       if ((unsigned)start_timestep + 1 == max_timestep) {
         data.fill(rdata);
       } else {
@@ -471,7 +472,8 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
         acti_func.run_prime_fn(cs, dc, dh);
         dc.multiply_i(ho);
       } else {
-        /// @todo optimize this by updating run_prime_fn to accumulate
+        /// @todo optimize this by updating run_prime_fn to accumulate or make
+        /// it in-place somehow
         Tensor dc_temp(dc.getDim());
         acti_func.run_prime_fn(cs, dc_temp, dh);
         dc_temp.multiply_i(ho);
diff --git a/packaging/unittest_models_v2.tar.gz b/packaging/unittest_models_v2.tar.gz
new file mode 100644 (file)
index 0000000..d689aeb
Binary files /dev/null and b/packaging/unittest_models_v2.tar.gz differ
diff --git a/test/input_gen/genModelsRecurrent_v2.py b/test/input_gen/genModelsRecurrent_v2.py
new file mode 100644 (file)
index 0000000..6f968da
--- /dev/null
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+##
+# Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
+#
+# @file genModelsRecurrent_v2.py
+# @date 19 October 2021
+# @brief Generate recurrent model test cases
+# @author Jihoon Lee <jhoon.it.lee@samsung.com>
+
+from recorder_v2 import record_v2, inspect_file
+import torch
+
+class FCUnroll(torch.nn.Module):
+    def __init__(self, unroll_for=1, num_fc=1):
+        super().__init__()
+        self.fcs = torch.nn.ModuleList([torch.nn.Linear(1, 1) for i in range(num_fc)])
+        self.unroll_for = unroll_for
+        # self.loss = torch.nn.MSELoss()
+        self.loss = torch.nn.Identity()
+
+    def forward(self, inputs, labels):
+        output = inputs[0]
+        for i in range(self.unroll_for):
+            for fc in self.fcs:
+                output = fc(output)
+        loss = self.loss(output)
+        # loss = self.loss(output, labels[0])
+        return output, loss
+
+
+class LSTMStacked(torch.nn.Module):
+    def __init__(self, unroll_for=2, num_lstm=1):
+        super().__init__()
+        self.input_size = self.hidden_size = 2
+        self.lstms = torch.nn.ModuleList(
+            [
+                torch.nn.LSTMCell(self.input_size, self.hidden_size, bias=True)
+                for _ in range(num_lstm)
+            ]
+        )
+        # self.lstm.weight_hh.data.fill_(1.0)
+        # self.lstm.weight_ih.data.fill_(1.0)
+        # self.lstm.bias_hh.data.fill_(1.0)
+        self.unroll_for = unroll_for
+        self.loss = torch.nn.MSELoss()
+
+    def forward(self, inputs, labels):
+        # the second bias (bias_ih) is zeroed so that its gradient is always zero;
+        # this is because we keep only one bias
+        for lstm in self.lstms:
+            lstm.bias_ih.data.fill_(0.0)
+
+        hs = [torch.zeros_like(inputs[0]) for _ in self.lstms]
+        cs = [torch.zeros_like(inputs[0]) for _ in self.lstms]
+        out = inputs[0]
+        ret = []
+        for _ in range(self.unroll_for):
+            for i, (lstm, h, c) in enumerate(zip(self.lstms, hs, cs)):
+                hs[i], cs[i] = lstm(out, (h, c))
+                out = hs[i]
+            ret.append(out)
+
+        ret = torch.stack(ret, dim=1)
+        loss = self.loss(ret, labels[0])
+        return ret, loss
+
+
+if __name__ == "__main__":
+    record_v2(
+        FCUnroll(unroll_for=5),
+        iteration=2,
+        input_dims=[(1,)],
+        label_dims=[(1,)],
+        name="fc_unroll_single",
+    )
+
+    record_v2(
+        FCUnroll(unroll_for=2, num_fc=2),
+        iteration=2,
+        input_dims=[(1,)],
+        label_dims=[(1,)],
+        name="fc_unroll_stacked",
+    )
+
+    record_v2(
+        LSTMStacked(unroll_for=2, num_lstm=1),
+        iteration=2,
+        input_dims=[(3, 2)],
+        label_dims=[(3, 2, 2)],
+        name="lstm_single",
+    )
+
+    record_v2(
+        LSTMStacked(unroll_for=2, num_lstm=2),
+        iteration=2,
+        input_dims=[(3, 2)],
+        label_dims=[(3, 2, 2)],
+        name="lstm_stacked",
+    )
+
+    # inspect_file("lstm_single.nnmodelgolden")
index 44f9daa..dfdc383 100644 (file)
 
 import os
 import random
-import torch
+import torch  # this script assumes torch==1.9.1
 import numpy as np
 
 from transLayer_v2 import params_translated
 
+if torch.__version__ != "1.9.1":
+    print(
+        "this script was tested with torch 1.9.1; it might not work with a different torch version"
+    )
+
 SEED = 1234
 random.seed(SEED)
 np.random.seed(SEED)
@@ -44,6 +49,7 @@ def _rand_like(*shapes, scale=1, rand="int"):
     np_array = map(shape_to_np, shapes)
     return [torch.tensor(t * scale) for t in np_array]
 
+
 ##
 # @brief record a torch model
 # @param iteration number of iteration to record
@@ -64,8 +70,6 @@ def record_v2(model, iteration, input_dims, label_dims, name):
 
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
 
-    print(*(model.named_parameters()))
-
     def record_iteration(write_fn):
         inputs = _rand_like(*input_dims, rand="float")
         labels = _rand_like(*label_dims, rand="float")
index c8781e1..3642286 100644 (file)
@@ -11,6 +11,7 @@ unzip_target = [
   ['unittest_layers.tar.gz', 'unittest_layers'],
   ['unittest_layers_v2.tar.gz', 'unittest_layers'],
   ['unittest_models.tar.gz', 'unittest_models'],
+  ['unittest_models_v2.tar.gz', 'unittest_models'],
 ]
 
 src_path = meson.source_root() / 'packaging'
index 7093101..e8df1c7 100644 (file)
 
 using namespace nntrainer;
 
+static inline constexpr const int NOT_USED_ = 1;
+
 static IniSection nn_base("model", "type = NeuralNetwork");
 static std::string fc_base = "type = Fully_connected";
 static IniSection sgd_base("optimizer", "Type = sgd");
 static IniSection constant_loss("loss", "type = constant_derivative");
 
-IniWrapper
-  fc_only_hand_unrolled("fc_only_hand_unrolled",
-                        {
-                          nn_base,
-                          sgd_base,
-                          IniSection("fc_1") + fc_base +
-                            "unit=1 | weight_initializer=ones | "
-                            "bias_initializer=ones | input_shape=1:1:1",
-                          IniSection("fc_2") + fc_base +
-                            "unit=1 | weight_initializer=ones | "
-                            "bias_initializer=ones | shared_from = fc_1",
-                          IniSection("fc_3") + fc_base +
-                            "unit=1 | weight_initializer=ones | "
-                            "bias_initializer=ones | shared_from = fc_1",
-                          constant_loss,
-                        });
-
-std::unique_ptr<NeuralNetwork> makeSingleLSTM() {
+IniWrapper fc_unroll_single(
+  "fc_unroll_single",
+  {
+    nn_base,
+    sgd_base + "learning_rate=0.1",
+    IniSection("fc_1") + fc_base + "unit=1 | input_shape=1:1:1",
+    IniSection("fc_2") + fc_base + "unit=1 | shared_from = fc_1",
+    IniSection("fc_3") + fc_base + "unit=1 | shared_from = fc_1",
+    IniSection("fc_4") + fc_base + "unit=1 | shared_from = fc_1",
+    IniSection("fc_5") + fc_base + "unit=1 | shared_from = fc_1",
+    constant_loss,
+  });
+
+std::unique_ptr<NeuralNetwork> makeFC() {
   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
   nn->setProperty({"batch_size=1"});
 
   auto outer_graph = makeGraph({
     {"input", {"name=input", "input_shape=1:1:1"}},
     /// here lstm_cells is being inserted
+    {"constant_derivative", {"name=loss", "input_layers=recurrent/a2"}},
+  });
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  auto fcfc = makeGraph({
+    {"Fully_connected", {"name=a1", "unit=1"}},
+    {"Fully_connected", {"name=a2", "unit=1", "input_layers=a1"}},
+  });
+
+  nn->addWithReferenceLayers(fcfc, "recurrent", {"input"}, {"a1"}, {"a2"},
+                             ml::train::ReferenceLayersType::RECURRENT,
+                             {
+                               "unroll_for=2",
+                               "return_sequences=false",
+                               "recurrent_input=a1",
+                               "recurrent_output=a2",
+                             });
+
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+  return nn;
+}
+
+static std::unique_ptr<NeuralNetwork> makeSingleLSTM() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  nn->setProperty({"batch_size=3"});
+
+  auto outer_graph = makeGraph({
+    {"input", {"name=input", "input_shape=1:1:2"}},
+    /// here lstm_cells is being inserted
     {"mse", {"name=loss", "input_layers=lstm_scope/a1"}},
   });
   for (auto &node : outer_graph) {
@@ -60,7 +89,7 @@ std::unique_ptr<NeuralNetwork> makeSingleLSTM() {
   }
 
   auto lstm = makeGraph({
-    {"lstm", {"name=a1", "input_shape=1:1:1", "unit=1"}},
+    {"lstm", {"name=a1", "unit=2"}},
   });
 
   nn->addWithReferenceLayers(lstm, "lstm_scope", {"input"}, {"a1"}, {"a1"},
@@ -76,14 +105,45 @@ std::unique_ptr<NeuralNetwork> makeSingleLSTM() {
   return nn;
 }
 
+static std::unique_ptr<NeuralNetwork> makeStackedLSTM() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  nn->setProperty({"batch_size=3"});
+
+  auto outer_graph = makeGraph({
+    {"input", {"name=input", "input_shape=1:1:2"}},
+    /// here lstm_cells is being inserted
+    {"mse", {"name=loss", "input_layers=lstm_scope/a2"}},
+  });
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  auto lstm = makeGraph({
+    {"lstm", {"name=a1", "unit=2"}},
+    {"lstm", {"name=a2", "unit=2", "input_layers=a1"}},
+  });
+
+  nn->addWithReferenceLayers(lstm, "lstm_scope", {"input"}, {"a1"}, {"a2"},
+                             ml::train::ReferenceLayersType::RECURRENT,
+                             {
+                               "unroll_for=2",
+                               "return_sequences=true",
+                               "recurrent_input=a1",
+                               "recurrent_output=a2",
+                             });
+
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+  return nn;
+}
+
 INSTANTIATE_TEST_CASE_P(
   recurrentModels, nntrainerModelTest,
   ::testing::ValuesIn({
-    mkModelIniTc(fc_only_hand_unrolled, "1:1:1", 1,
-                 ModelTestOption::NO_THROW_RUN),
-    /// @todo make below COMPARE
-    mkModelTc(makeSingleLSTM, "lstm_return_sequence", "1:2:1", 1,
-              ModelTestOption::NO_THROW_RUN),
+    mkModelIniTc(fc_unroll_single, DIM_UNUSED, NOT_USED_,
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeFC, "fc_unroll_stacked", ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeSingleLSTM, "lstm_single", ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeStackedLSTM, "lstm_stacked", ModelTestOption::COMPARE_V2),
   }),
   [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info) {
     return std::get<1>(info.param);