[unittest] Implement grucell unittest
author    hyeonseok lee <hs89.lee@samsung.com>
          Mon, 22 Nov 2021 06:49:36 +0000 (15:49 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
          Thu, 25 Nov 2021 10:19:19 +0000 (19:19 +0900)
 - Verify grucell against tensorflow via a layer unittest
 - Add grucell model unittests to verify multi-unit/multi-timestep, stacked, and unrolled cases against pytorch
 - TODO: find a way to convert the gate order of the pytorch grucell without copying (see the sketch below)
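
For context, a minimal sketch of the gate-order conversion the TODO refers to (illustrative only, not part of the commit): pytorch packs GRU gates as [reset, update, new] along the gate axis, while nntrainer's grucell expects [update, reset, new], so the translator splits and re-stacks the packed weights, which materializes a copy.

    import torch

    hidden, inp = 4, 2
    w_ih = torch.randn(3 * hidden, inp)     # pytorch layout: [r; z; n] rows
    r, z, n = w_ih.chunk(3, dim=0)          # split along the gate axis
    w_ih_nntrainer = torch.cat((z, r, n))   # re-stack as [z; r; n] -- copies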

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
api/ccapi/include/layer.h
jni/Android.mk
nntrainer/app_context.cpp
nntrainer/compiler/recurrent_realizer.cpp
packaging/unittest_models_v2.tar.gz
test/input_gen/genModelsRecurrent_v2.py
test/input_gen/transLayer_v2.py
test/unittest/layers/meson.build
test/unittest/layers/unittest_layers_grucell.cpp [new file with mode: 0644]
test/unittest/models/unittest_models_recurrent.cpp

index a4fe4ea..bad2f79 100644 (file)
@@ -70,6 +70,7 @@ enum LayerType {
   LAYER_RESHAPE,                           /**< Reshape Layer type */
   LAYER_RNNCELL,                           /**< RNN Cell Layer type */
   LAYER_LSTMCELL,                          /**< LSTM Cell Layer type */
+  LAYER_GRUCELL,                           /**< GRU Cell Layer type */
   LAYER_LOSS_MSE = 500,             /**< Mean Squared Error Loss Layer type */
   LAYER_LOSS_CROSS_ENTROPY_SIGMOID, /**< Cross Entropy with Sigmoid Loss Layer
                                        type */
@@ -344,6 +345,14 @@ GRU(const std::vector<std::string> &properties = {}) {
 }
 
 /**
+ * @brief Helper function to create GRUCell layer
+ */
+inline std::unique_ptr<Layer>
+GRUCell(const std::vector<std::string> &properties = {}) {
+  return createLayer(LayerType::LAYER_GRUCELL, properties);
+}
+
+/**
  * @brief Helper function to create DropOut layer
  */
 inline std::unique_ptr<Layer>
index 7f3c11c..5b0ecea 100644 (file)
@@ -173,6 +173,7 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/lstm.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/lstmcell.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/gru.cpp \
+                  $(NNTRAINER_ROOT)/nntrainer/layers/grucell.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/time_dist.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/dropout.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/permute_layer.cpp \
index 309c01d..1b6a054 100644 (file)
@@ -44,6 +44,7 @@
 #include <fc_layer.h>
 #include <flatten_layer.h>
 #include <gru.h>
+#include <grucell.h>
 #include <input_layer.h>
 #include <lstm.h>
 #include <lstmcell.h>
@@ -250,6 +251,8 @@ static void add_default_object(AppContext &ac) {
                      LayerType::LAYER_SPLIT);
   ac.registerFactory(nntrainer::createLayer<GRULayer>, GRULayer::type,
                      LayerType::LAYER_GRU);
+  ac.registerFactory(nntrainer::createLayer<GRUCellLayer>, GRUCellLayer::type,
+                     LayerType::LAYER_GRUCELL);
   ac.registerFactory(nntrainer::createLayer<PermuteLayer>, PermuteLayer::type,
                      LayerType::LAYER_PERMUTE);
   ac.registerFactory(nntrainer::createLayer<DropOutLayer>, DropOutLayer::type,
index a5c3923..02df5ab 100644 (file)
@@ -13,6 +13,7 @@
 #include <recurrent_realizer.h>
 
 #include <common_properties.h>
+#include <grucell.h>
 #include <input_layer.h>
 #include <layer_node.h>
 #include <lstm.h>
@@ -130,7 +131,8 @@ static void propagateTimestep(LayerNode *node, unsigned int time_step,
   auto is_recurrent_type = [](LayerNode *node) {
     return node->getType() == RNNCellLayer::type ||
            node->getType() == LSTMLayer::type ||
-           node->getType() == LSTMCellLayer::type;
+           node->getType() == LSTMCellLayer::type ||
+           node->getType() == GRUCellLayer::type;
   };
 
   if (is_recurrent_type(node)) {
index cd19c8a..dea2ea1 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
index e4d58d1..cbbc884 100644 (file)
@@ -88,6 +88,36 @@ class LSTMStacked(torch.nn.Module):
         loss = self.loss(ret, labels[0])
         return ret, loss
 
+class GRUCellStacked(torch.nn.Module):
+    def __init__(self, unroll_for=2, num_gru=1):
+        super().__init__()
+        self.input_size = self.hidden_size = 2
+        self.grus = torch.nn.ModuleList(
+            [
+                torch.nn.GRUCell(self.input_size, self.hidden_size, bias=True)
+                for _ in range(num_gru)
+            ]
+        )
+        for gru in self.grus:
+            gru.bias_hh.data.fill_(0.0)  # zeroed so both biases fold into one
+            gru.bias_hh.requires_grad = False
+        self.unroll_for = unroll_for
+        self.loss = torch.nn.MSELoss()
+
+    def forward(self, inputs, labels):
+        hs = [torch.zeros_like(inputs[0]) for _ in self.grus]
+        out = inputs[0]
+        ret = []
+        for _ in range(self.unroll_for):
+            for i, (gru, h) in enumerate(zip(self.grus, hs)):
+                hs[i] = gru(out, h)
+                out = hs[i]
+            ret.append(out)
+
+        ret = torch.stack(ret, dim=1)
+        loss = self.loss(ret, labels[0])
+        return ret, loss
+
 if __name__ == "__main__":
     record_v2(
         FCUnroll(unroll_for=5),
@@ -137,4 +167,20 @@ if __name__ == "__main__":
         name="lstm_stacked",
     )
 
+    record_v2(
+        GRUCellStacked(unroll_for=2, num_gru=1),
+        iteration=2,
+        input_dims=[(3, 2)],
+        label_dims=[(3, 2, 2)],
+        name="grucell_single",
+    )
+
+    record_v2(
+        GRUCellStacked(unroll_for=2, num_gru=2),
+        iteration=2,
+        input_dims=[(3, 2)],
+        label_dims=[(3, 2, 2)],
+        name="grucell_stacked",
+    )
+
     # inspect_file("lstm_single.nnmodelgolden")
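
A note on the zeroed bias_hh above (my reading; the commit does not state it): pytorch's GRUCell computes n = tanh(W_in x + b_in + r * (W_hn h + b_hn)), so b_hn is scaled by the reset gate and cannot be folded into a single bias in general. Zeroing and freezing bias_hh makes the translator's bias_ih + bias_hh sum exact. A sketch demonstrating the mismatch when b_hn is nonzero:

    import torch

    torch.manual_seed(0)
    cell = torch.nn.GRUCell(2, 2, bias=True)
    x, h = torch.randn(3, 2), torch.randn(3, 2)

    with torch.no_grad():
        out_full = cell(x, h)          # b_hn active inside the reset-gated term
        cell.bias_ih += cell.bias_hh   # fold b_hh into b_ih ...
        cell.bias_hh.zero_()           # ... and zero b_hh
        out_merged = cell(x, h)

    # False in general: folding is only exact when bias_hh was already zero
    print(torch.allclose(out_full, out_merged))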
index a4f6aba..db1dc36 100644 (file)
@@ -54,16 +54,40 @@ def bn1d_translate(model):
 
 
 @register_for_((torch.nn.RNNCell, torch.nn.LSTMCell))
-def lstm_translate(model):
+def rnn_lstm_translate(model):
     params = [(name, tensor.detach()) for name, tensor in model.named_parameters()]
     bias = ("bias", params[2][1] + params[3][1])
-    # hidden, input -> input, hidden
+    # [hidden, input] -> [input, hidden]
     def transpose_(weight):
         return (weight[0], weight[1].transpose(1, 0))
 
     new_params = [transpose_(params[0]), transpose_(params[1]), bias]
     yield from new_params
 
+@register_for_((torch.nn.GRUCell,))
+def gru_translate(model):
+    params = [(name, tensor.detach()) for name, tensor in model.named_parameters()]
+    bias = ("bias", params[2][1] + params[3][1])
+
+    # [hidden, input] -> [input, hidden]
+    def transpose_(weight):
+        return (weight[0], weight[1].transpose(1, 0))
+
+    # reorder gates: (resetgate, inputgate, newgate) -> (inputgate, resetgate, newgate)
+    # to match nntrainer's grucell gate layout
+    def reorder_weight(param):
+        name, tensor = param
+        # gates are packed along the last axis of the transposed 2-D weights;
+        # for the 1-D bias, hsplit splits along its only axis
+        reset, update, new = tensor.hsplit(3)
+        return (name, torch.hstack((update, reset, new)))
+
+    transposed_params = [transpose_(params[0]), transpose_(params[1]), bias]
+    new_params = [reorder_weight(transposed_params[0]),
+                  reorder_weight(transposed_params[1]), reorder_weight(transposed_params[2])]
+
+    yield from new_params
+
 def translate(model):
     for child in model.children():
         for registered_classes, fn in handler_book:
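
A quick way to eyeball the translated layout (a sketch, not part of the commit; assumes this module is importable as transLayer_v2):

    import torch
    from transLayer_v2 import gru_translate

    cell = torch.nn.GRUCell(3, 4, bias=True)
    for name, tensor in gru_translate(cell):
        print(name, tuple(tensor.shape))
    # after transpose + reorder:
    #   weight_ih (3, 12)
    #   weight_hh (4, 12)
    #   bias (12,)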
index df37a1b..1925c9c 100644 (file)
@@ -47,6 +47,7 @@ test_target = [
   'unittest_layers_lstm.cpp',
   'unittest_layers_lstmcell.cpp',
   'unittest_layers_gru.cpp',
+  'unittest_layers_grucell.cpp',
   'unittest_layers_preprocess_flip.cpp',
   'unittest_layers_split.cpp',
   'unittest_layers_embedding.cpp',
diff --git a/test/unittest/layers/unittest_layers_grucell.cpp b/test/unittest/layers/unittest_layers_grucell.cpp
new file mode 100644 (file)
index 0000000..23a2361
--- /dev/null
@@ -0,0 +1,33 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
+ *
+ * @file unittest_layers_grucell.cpp
+ * @date 09 November 2021
+ * @brief GRUCell Layer Test
+ * @see        https://github.com/nnstreamer/nntrainer
+ * @author hyeonseok lee <hs89.lee@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+#include <tuple>
+
+#include <gtest/gtest.h>
+
+#include <grucell.h>
+#include <layers_common_tests.h>
+
+auto semantic_grucell = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::GRUCellLayer>,
+  nntrainer::GRUCellLayer::type, {"unit=1", "max_timestep=1", "timestep=0"}, 0,
+  false, 1);
+
+INSTANTIATE_TEST_CASE_P(GRUCell, LayerSemantics,
+                        ::testing::Values(semantic_grucell));
+
+auto grucell_single_step = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::GRUCellLayer>,
+  {"unit=5", "max_timestep=1", "timestep=0"}, "3:1:1:7",
+  "gru_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+
+INSTANTIATE_TEST_CASE_P(GRUCell, LayerGoldenTest,
+                        ::testing::Values(grucell_single_step));
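
Note the golden file here: grucell_single_step reuses gru_single_step.nnlayergolden, presumably on purpose, since a GRU unrolled for a single timestep computes exactly one GRUCell step; this would also explain why only packaging/unittest_models_v2.tar.gz changes in this commit and no layer-golden tarball does.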
index ef32cfa..75d044f 100644 (file)
@@ -272,6 +272,67 @@ static std::unique_ptr<NeuralNetwork> makeStackedRNNCell() {
   return nn;
 }
 
+static std::unique_ptr<NeuralNetwork> makeSingleGRUCell() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  nn->setProperty({"batch_size=3"});
+
+  auto outer_graph = makeGraph({
+    {"input", {"name=input", "input_shape=1:1:2"}},
+    /// here grucell is being inserted
+    {"mse", {"name=loss", "input_layers=grucell_scope/a1"}},
+  });
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  auto grucell = makeGraph({
+    {"grucell", {"name=a1", "unit=2"}},
+  });
+
+  nn->addWithReferenceLayers(grucell, "grucell_scope", {"input"}, {"a1"},
+                             {"a1"}, ml::train::ReferenceLayersType::RECURRENT,
+                             {
+                               "unroll_for=2",
+                               "return_sequences=true",
+                               "recurrent_input=a1",
+                               "recurrent_output=a1",
+                             });
+
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+  return nn;
+}
+
+static std::unique_ptr<NeuralNetwork> makeStackedGRUCell() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  nn->setProperty({"batch_size=3"});
+
+  auto outer_graph = makeGraph({
+    {"input", {"name=input", "input_shape=1:1:2"}},
+    /// here grucells are being inserted
+    {"mse", {"name=loss", "input_layers=grucell_scope/a2"}},
+  });
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  auto grucell = makeGraph({
+    {"grucell", {"name=a1", "unit=2"}},
+    {"grucell", {"name=a2", "unit=2", "input_layers=a1"}},
+  });
+
+  nn->addWithReferenceLayers(grucell, "grucell_scope", {"input"}, {"a1"},
+                             {"a2"}, ml::train::ReferenceLayersType::RECURRENT,
+                             {
+                               "unroll_for=2",
+                               "return_sequences=true",
+                               "recurrent_input=a1",
+                               "recurrent_output=a2",
+                             });
+
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+  return nn;
+}
+
 INSTANTIATE_TEST_CASE_P(
   recurrentModels, nntrainerModelTest,
   ::testing::ValuesIn({
@@ -290,6 +351,10 @@ INSTANTIATE_TEST_CASE_P(
                  ModelTestOption::COMPARE_V2),
     mkModelTc_V2(makeStackedRNNCell, "rnncell_stacked__1",
                  ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeSingleGRUCell, "grucell_single__1",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeStackedGRUCell, "grucell_stacked__1",
+                 ModelTestOption::COMPARE_V2),
   }),
   [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info) {
     return std::get<1>(info.param);