[unittest] Implement rnncell unittest
author hyeonseok lee <hs89.lee@samsung.com>
Fri, 5 Nov 2021 07:14:32 +0000 (16:14 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Wed, 10 Nov 2021 09:11:55 +0000 (18:11 +0900)
 - Generate rnn and rnncell layer unittests
 - Generate model unittests composed of rnncell
 - Verified in multi-timestep, stacked, and loop_unrolling situations

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/compiler/recurrent_realizer.cpp
packaging/unittest_layers_v2.tar.gz
packaging/unittest_models_v2.tar.gz
test/input_gen/genLayerTests.py
test/input_gen/genModelsRecurrent_v2.py
test/input_gen/transLayer_v2.py
test/unittest/layers/meson.build
test/unittest/layers/unittest_layers_rnn.cpp
test/unittest/layers/unittest_layers_rnncell.cpp [new file with mode: 0644]
test/unittest/models/unittest_models_recurrent.cpp

diff --git a/nntrainer/compiler/recurrent_realizer.cpp b/nntrainer/compiler/recurrent_realizer.cpp
index 3e17130..d4daa4c 100644 (file)
@@ -20,6 +20,8 @@
 #include <nntrainer_error.h>
 #include <node_exporter.h>
 #include <remap_realizer.h>
+#include <rnn.h>
+#include <rnncell.h>
 #include <util_func.h>
 
 namespace nntrainer {
@@ -127,7 +129,8 @@ static void propagateTimestep(LayerNode *node, unsigned int time_step,
 
   /** @todo add an interface to check if a layer supports a property */
   auto is_recurrent_type = [](LayerNode *node) {
-    return node->getType() == LSTMLayer::type ||
+    return node->getType() == RNNCellLayer::type ||
+           node->getType() == LSTMLayer::type ||
            node->getType() == LSTMCellLayer::type;
   };
 
index 1f8be21..d6b19f8 100644 (file)
Binary files a/packaging/unittest_layers_v2.tar.gz and b/packaging/unittest_layers_v2.tar.gz differ
index 06d35aa..cd19c8a 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
diff --git a/test/input_gen/genLayerTests.py b/test/input_gen/genLayerTests.py
index 4176235..d132423 100644 (file)
@@ -90,6 +90,12 @@ if __name__ == "__main__":
     record_single(attention, [(2, 5, 7), (2, 3, 7), (2, 3, 7)],
                  "attention_batched", {}, input_type='float')
 
+    rnn = K.layers.SimpleRNN(units=5,
+                             activation="tanh",
+                             return_sequences=False,
+                             return_state=False)
+    record_single(rnn, (3, 1, 7), "rnn_single_step")
+
     lstm = K.layers.LSTM(units=5,
                          recurrent_activation="sigmoid",
                          activation="tanh",
diff --git a/test/input_gen/genModelsRecurrent_v2.py b/test/input_gen/genModelsRecurrent_v2.py
index cec9622..6a48787 100644 (file)
@@ -28,6 +28,40 @@ class FCUnroll(torch.nn.Module):
         # loss = self.loss(output, labels[0])
         return output, loss
 
+class RNNCellStacked(torch.nn.Module):
+    def __init__(self, unroll_for=1, num_rnn=1, input_size=1, hidden_size=1):
+        super().__init__()
+        self.rnns = torch.nn.ModuleList(
+            [
+                torch.nn.RNNCell(input_size, hidden_size)
+                for _ in range(num_rnn)
+            ]
+        )
+        for rnn in self.rnns:
+            rnn.bias_ih.data.fill_(0.0)
+            rnn.bias_ih.requires_grad = False
+        self.unroll_for = unroll_for
+        self.loss = torch.nn.MSELoss()
+
+    def forward(self, inputs, labels):
+        # bias_ih is re-zeroed on every forward so its grad stays zero;
+        # nntrainer's rnncell keeps only a single bias (here, bias_hh).
+        for rnn in self.rnns:
+            rnn.bias_ih.data.fill_(0.0)
+
+        hs = [torch.zeros_like(inputs[0]) for _ in self.rnns]
+        out = inputs[0]
+        ret = []
+        for _ in range(self.unroll_for):
+            for i, rnn in enumerate(self.rnns):
+                hs[i] = rnn(out, hs[i])
+                out = hs[i]
+            ret.append(out)
+
+        ret = torch.stack(ret, dim=1)
+        loss = self.loss(ret, labels[0])
+        return ret, loss
+
 class LSTMStacked(torch.nn.Module):
     def __init__(self, unroll_for=2, num_lstm=1):
         super().__init__()
@@ -85,6 +119,22 @@ if __name__ == "__main__":
     )
 
     record_v2(
+        RNNCellStacked(unroll_for=2, num_rnn=1, input_size=2, hidden_size=2),
+        iteration=2,
+        input_dims=[(3, 2)],
+        label_dims=[(3, 2, 2)],
+        name="rnncell_single",
+    )
+
+    record_v2(
+        RNNCellStacked(unroll_for=2, num_rnn=2, input_size=2, hidden_size=2),
+        iteration=2,
+        input_dims=[(3, 2)],
+        label_dims=[(3, 2, 2)],
+        name="rnncell_stacked",
+    )
+
+    record_v2(
         LSTMStacked(unroll_for=2, num_lstm=1),
         iteration=2,
         input_dims=[(3, 2)],
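
A note on why RNNCellStacked pins bias_ih to zero: torch.nn.RNNCell carries two biases (bias_ih and bias_hh) while nntrainer's rnncell keeps a single bias, so one torch bias is zeroed to keep the parameter sets comparable. A small standalone sanity check of that equivalence (illustrative only, not part of the generator):

    import torch

    cell = torch.nn.RNNCell(2, 2)
    cell.bias_ih.data.fill_(0.0)  # pin the first bias; bias_hh is the single effective bias

    x = torch.randn(3, 2)
    h = torch.zeros(3, 2)

    # with bias_ih == 0, torch's update
    #   h' = tanh(x @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
    # collapses to the single-bias form
    #   h' = tanh(x @ W_ih.T + h @ W_hh.T + b)
    out = cell(x, h)
    manual = torch.tanh(x @ cell.weight_ih.T + h @ cell.weight_hh.T + cell.bias_hh)
    assert torch.allclose(out, manual, atol=1e-6)
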
diff --git a/test/input_gen/transLayer_v2.py b/test/input_gen/transLayer_v2.py
index 897d4d8..a4f6aba 100644 (file)
@@ -53,7 +53,7 @@ def bn1d_translate(model):
     yield from [mu, var, gamma, beta]
 
 
-@register_for_(torch.nn.LSTMCell)
+@register_for_((torch.nn.RNNCell, torch.nn.LSTMCell))
 def lstm_translate(model):
     params = [(name, tensor.detach()) for name, tensor in model.named_parameters()]
     bias = ("bias", params[2][1] + params[3][1])
diff --git a/test/unittest/layers/meson.build b/test/unittest/layers/meson.build
index 596da1e..df37a1b 100644 (file)
@@ -43,6 +43,7 @@ test_target = [
   'unittest_layers_addition.cpp',
   'unittest_layers_multiout.cpp',
   'unittest_layers_rnn.cpp',
+  'unittest_layers_rnncell.cpp',
   'unittest_layers_lstm.cpp',
   'unittest_layers_lstmcell.cpp',
   'unittest_layers_gru.cpp',
diff --git a/test/unittest/layers/unittest_layers_rnn.cpp b/test/unittest/layers/unittest_layers_rnn.cpp
index 95d2ce5..e839892 100644 (file)
@@ -21,3 +21,11 @@ auto semantic_rnn =
                           nntrainer::RNNLayer::type, {"unit=1"}, 0, false, 1);
 
 INSTANTIATE_TEST_CASE_P(RNN, LayerSemantics, ::testing::Values(semantic_rnn));
+
+auto rnn_single_step = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::RNNLayer>,
+  {"unit=5", "return_sequences=false"}, "3:1:1:7",
+  "rnn_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+
+INSTANTIATE_TEST_CASE_P(RNN, LayerGoldenTest,
+                        ::testing::Values(rnn_single_step));
diff --git a/test/unittest/layers/unittest_layers_rnncell.cpp b/test/unittest/layers/unittest_layers_rnncell.cpp
new file mode 100644 (file)
index 0000000..cf8594e
--- /dev/null
@@ -0,0 +1,33 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
+ *
+ * @file unittest_layers_rnncell.cpp
+ * @date 1 November 2021
+ * @brief RNNCell Layer Test
+ * @see        https://github.com/nnstreamer/nntrainer
+ * @author hyeonseok lee <hs89.lee@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+#include <tuple>
+
+#include <gtest/gtest.h>
+
+#include <layers_common_tests.h>
+#include <rnncell.h>
+
+auto semantic_rnncell = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::RNNCellLayer>,
+  nntrainer::RNNCellLayer::type, {"unit=1", "timestep=0", "max_timestep=1"}, 0,
+  false, 1);
+
+INSTANTIATE_TEST_CASE_P(RNNCell, LayerSemantics,
+                        ::testing::Values(semantic_rnncell));
+
+auto rnncell_single_step = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::RNNCellLayer>,
+  {"unit=5", "timestep=0", "max_timestep=1"}, "3:1:1:7",
+  "rnn_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+
+INSTANTIATE_TEST_CASE_P(RNNCell, LayerGoldenTest,
+                        ::testing::Values(rnncell_single_step));
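
Note that this golden test deliberately reuses rnn_single_step.nnlayergolden: running an RNN layer over a single timestep is the same computation as one RNNCell invocation. A minimal torch sketch of that equivalence follows (the golden data itself comes from the Keras recording in genLayerTests.py above):

    import torch

    rnn = torch.nn.RNN(input_size=7, hidden_size=5, batch_first=True)
    cell = torch.nn.RNNCell(7, 5)

    # copy the layer weights into the cell so both compute the same step
    cell.weight_ih.data.copy_(rnn.weight_ih_l0.data)
    cell.weight_hh.data.copy_(rnn.weight_hh_l0.data)
    cell.bias_ih.data.copy_(rnn.bias_ih_l0.data)
    cell.bias_hh.data.copy_(rnn.bias_hh_l0.data)

    x = torch.randn(3, 1, 7)   # (batch, time=1, feature), matching "3:1:1:7"
    out, _ = rnn(x)            # whole layer, one timestep
    step = cell(x[:, 0, :])    # one explicit cell step
    assert torch.allclose(out[:, 0, :], step, atol=1e-6)
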
diff --git a/test/unittest/models/unittest_models_recurrent.cpp b/test/unittest/models/unittest_models_recurrent.cpp
index 10b2128..9e9f336 100644 (file)
@@ -197,6 +197,67 @@ static std::unique_ptr<NeuralNetwork> makeStackedLSTMCell() {
   return nn;
 }
 
+static std::unique_ptr<NeuralNetwork> makeSingleRNNCell() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  nn->setProperty({"batch_size=3"});
+
+  auto outer_graph = makeGraph({
+    {"input", {"name=input", "input_shape=1:1:2"}},
+    /// the rnncell is inserted here via addWithReferenceLayers below
+    {"mse", {"name=loss", "input_layers=rnncell_scope/a1"}},
+  });
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  auto rnncell = makeGraph({
+    {"rnncell", {"name=a1", "unit=2"}},
+  });
+
+  nn->addWithReferenceLayers(rnncell, "rnncell_scope", {"input"}, {"a1"},
+                             {"a1"}, ml::train::ReferenceLayersType::RECURRENT,
+                             {
+                               "unroll_for=2",
+                               "return_sequences=true",
+                               "recurrent_input=a1",
+                               "recurrent_output=a1",
+                             });
+
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+  return nn;
+}
+
+static std::unique_ptr<NeuralNetwork> makeStackedRNNCell() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  nn->setProperty({"batch_size=3"});
+
+  auto outer_graph = makeGraph({
+    {"input", {"name=input", "input_shape=1:1:2"}},
+    /// the stacked rnncells are inserted here via addWithReferenceLayers below
+    {"mse", {"name=loss", "input_layers=rnncell_scope/a2"}},
+  });
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  auto rnncell = makeGraph({
+    {"rnncell", {"name=a1", "unit=2"}},
+    {"rnncell", {"name=a2", "unit=2", "input_layers=a1"}},
+  });
+
+  nn->addWithReferenceLayers(rnncell, "rnncell_scope", {"input"}, {"a1"},
+                             {"a2"}, ml::train::ReferenceLayersType::RECURRENT,
+                             {
+                               "unroll_for=2",
+                               "return_sequences=true",
+                               "recurrent_input=a1",
+                               "recurrent_output=a2",
+                             });
+
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+  return nn;
+}
+
 INSTANTIATE_TEST_CASE_P(
   recurrentModels, nntrainerModelTest,
   ::testing::ValuesIn({
@@ -209,6 +270,10 @@ INSTANTIATE_TEST_CASE_P(
     mkModelTc_V2(makeStackedLSTM, "lstm_stacked", ModelTestOption::COMPARE_V2),
     mkModelTc_V2(makeStackedLSTMCell, "lstm_stacked__1",
                  ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeSingleRNNCell, "rnncell_single__1",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeStackedRNNCell, "rnncell_stacked__1",
+                 ModelTestOption::COMPARE_V2),
   }),
   [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info) {
     return std::get<1>(info.param);