[layer] Unittests for reduce mean layer
authorParichay Kapoor <pk.kapoor@samsung.com>
Thu, 25 Nov 2021 13:46:25 +0000 (22:46 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 2 Dec 2021 06:54:18 +0000 (15:54 +0900)
This patch adds unittests for reduce mean layer with bug fix.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
api/ccapi/include/layer.h
jni/Android.mk
nntrainer/app_context.cpp
nntrainer/layers/common_properties.h
nntrainer/layers/reduce_mean_layer.cpp
nntrainer/layers/reduce_mean_layer.h
packaging/unittest_models_v2.tar.gz
test/input_gen/genModelTests_v2.py [new file with mode: 0644]
test/unittest/models/meson.build
test/unittest/models/unittest_models.cpp [new file with mode: 0644]

index 62d62b6..d49ca6f 100644 (file)
@@ -72,6 +72,7 @@ enum LayerType {
   LAYER_RNNCELL,                           /**< RNN Cell Layer type */
   LAYER_LSTMCELL,                          /**< LSTM Cell Layer type */
   LAYER_GRUCELL,                           /**< GRU Cell Layer type */
+  LAYER_REDUCE_MEAN,                       /**< Reduce mean Layer type */
   LAYER_LOSS_MSE = 500,             /**< Mean Squared Error Loss Layer type */
   LAYER_LOSS_CROSS_ENTROPY_SIGMOID, /**< Cross Entropy with Sigmoid Loss Layer
                                        type */
@@ -394,6 +395,14 @@ Permute(const std::vector<std::string> &properties = {}) {
 }
 
 /**
+ * @brief Helper function to create Reduce Mean Layer
+ */
+inline std::unique_ptr<Layer>
+ReduceMean(const std::vector<std::string> &properties = {}) {
+  return createLayer(LayerType::LAYER_REDUCE_MEAN, properties);
+}
+
+/**
  * @brief Helper function to create activation layer
  */
 inline std::unique_ptr<Layer>
index aca7008..16a19cd 100644 (file)
@@ -184,6 +184,7 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/split_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/common_properties.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/layer_impl.cpp \
+                  $(NNTRAINER_ROOT)/nntrainer/layers/reduce_mean_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/graph/network_graph.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/graph/graph_core.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/optimizers/optimizer_context.cpp \
index 409f5e4..7a52e70 100644 (file)
@@ -59,6 +59,7 @@
 #include <preprocess_flip_layer.h>
 #include <preprocess_l2norm_layer.h>
 #include <preprocess_translate_layer.h>
+#include <reduce_mean_layer.h>
 #include <rnn.h>
 #include <rnncell.h>
 #include <split_layer.h>
@@ -262,6 +263,8 @@ static void add_default_object(AppContext &ac) {
                      AttentionLayer::type, LayerType::LAYER_ATTENTION);
   ac.registerFactory(nntrainer::createLayer<MoLAttentionLayer>,
                      MoLAttentionLayer::type, LayerType::LAYER_MOL_ATTENTION);
+  ac.registerFactory(nntrainer::createLayer<ReduceMeanLayer>,
+                     ReduceMeanLayer::type, LayerType::LAYER_REDUCE_MEAN);
 
 #ifdef ENABLE_NNSTREAMER_BACKBONE
   ac.registerFactory(nntrainer::createLayer<NNStreamerLayer>,
index ea50276..ea136be 100644 (file)
@@ -349,6 +349,12 @@ public:
 class ConcatDimension : public SplitDimension {};
 
 /**
+ * @brief ReduceDimension property, dimension along which to reduce the input
+ *
+ */
+class ReduceDimension : public SplitDimension {};
+
+/**
  * @brief FilterSize property, filter size is used to measure how many filters
  * are there
  *
index c73f6fb..7710691 100644 (file)
@@ -27,9 +27,11 @@ void ReduceMeanLayer::finalize(InitLayerContext &context) {
   const TensorDim &in_dim = context.getInputDimensions()[0];
   TensorDim out_dim = in_dim;
 
-  /** if reduce axis is not provided, reduction is performed across all the
-   * dimensions */
-  auto &reduce_axis = std::get<props::Axis>(reduce_mean_props);
+  /**
+   * if reduce axis is not provided, reduction is performed across all the
+   * dimensions except the batch
+   */
+  auto &reduce_axis = std::get<props::ReduceDimension>(reduce_mean_props);
   if (reduce_axis.empty()) {
     out_dim = TensorDim({1, 1, 1, 1});
   }
@@ -39,10 +41,10 @@ void ReduceMeanLayer::finalize(InitLayerContext &context) {
 }
 
 void ReduceMeanLayer::forwarding(RunLayerContext &context, bool training) {
-  auto &reduce_axis = std::get<props::Axis>(reduce_mean_props);
+  auto &reduce_axis = std::get<props::ReduceDimension>(reduce_mean_props);
   if (reduce_axis.empty()) {
     context.getInput(SINGLE_INOUT_IDX)
-      .average(context.getOutput(SINGLE_INOUT_IDX));
+      .average({1, 2, 3}, context.getOutput(SINGLE_INOUT_IDX));
   } else {
     context.getInput(SINGLE_INOUT_IDX)
       .average(reduce_axis, context.getOutput(SINGLE_INOUT_IDX));
@@ -50,19 +52,14 @@ void ReduceMeanLayer::forwarding(RunLayerContext &context, bool training) {
 }
 
 void ReduceMeanLayer::calcDerivative(RunLayerContext &context) {
-  auto &deriv = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
-  auto &ret_deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
+  auto &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
+  auto &ret_deriv = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
 
   unsigned int div = ret_deriv.size() / deriv.size();
-  auto &reduce_axis = std::get<props::Axis>(reduce_mean_props);
 
-  if (reduce_axis.empty()) {
-    ret_deriv.setValue(deriv.getValue(0));
-  } else {
-    /** TODO: optimize this by supporting broadcast in copy */
-    ret_deriv.setZero();
-    ret_deriv.add_i(deriv);
-  }
+  /** TODO: optimize this by supporting broadcast in copy */
+  ret_deriv.setZero();
+  ret_deriv.add_i(deriv);
 
   ret_deriv.divide_i(div);
 }
index 75b5171..6cb54f7 100644 (file)
@@ -15,6 +15,8 @@
 #define __REDUCE_MEAN_LAYER_H__
 #ifdef __cplusplus
 
+#include <common_properties.h>
+#include <layer_context.h>
 #include <layer_devel.h>
 
 namespace nntrainer {
@@ -86,7 +88,7 @@ public:
 
 private:
   /** TODO: support scalar multiplier to simulate reduce_sum */
-  std::tuple<props::Axis>
+  std::tuple<props::ReduceDimension>
     reduce_mean_props; /**< reduce_mean properties : axis to reduce along */
 };
 
index c2a657f..7889c91 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
diff --git a/test/input_gen/genModelTests_v2.py b/test/input_gen/genModelTests_v2.py
new file mode 100644 (file)
index 0000000..f090382
--- /dev/null
@@ -0,0 +1,35 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+##
+# Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
+#
+# @file genModelTests_v2.py
+# @date 25 November 2021
+# @brief Generate model tcs
+# @author Parichay Kapoor <pk.kapoor@samsung.com>
+
+from recorder_v2 import record_v2, inspect_file
+import torch
+
+class ReduceMeanLast(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc = torch.nn.Linear(2, 7)
+        self.loss = torch.nn.Identity()
+
+    def forward(self, inputs, labels):
+        out = self.fc(inputs[0])
+        out = torch.mean(out, dim=-1)
+        loss = self.loss(torch.sum(out))
+        return out, loss
+
+if __name__ == "__main__":
+    record_v2(
+        ReduceMeanLast(),
+        iteration=2,
+        input_dims=[(3, 2,)],
+        label_dims=[(3, 1,)],
+        name="reduce_mean_last",
+    )
+
+    # inspect_file("lstm_single.nnmodelgolden")
index 3864d3b..4ac104e 100644 (file)
@@ -6,7 +6,8 @@ models_targets = [
   'models_test_utils.cpp',
   'models_golden_test.cpp',
   'unittest_models_recurrent.cpp',
-  'unittest_models_multiout.cpp'
+  'unittest_models_multiout.cpp',
+  'unittest_models.cpp',
 ]
 
 test_target += models_targets
diff --git a/test/unittest/models/unittest_models.cpp b/test/unittest/models/unittest_models.cpp
new file mode 100644 (file)
index 0000000..a2384da
--- /dev/null
@@ -0,0 +1,51 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file unittest_models_v2.cpp
+ * @date 25 Nov 2021
+ * @brief unittest models for v2 version
+ * @see        https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include <ini_wrapper.h>
+#include <neuralnet.h>
+#include <nntrainer_test_util.h>
+
+#include <models_golden_test.h>
+
+using namespace nntrainer;
+
+static inline constexpr const int NOT_USED_ = 1;
+
+static IniSection nn_base("model", "type = NeuralNetwork");
+static std::string fc_base = "type = Fully_connected";
+static std::string red_mean_base = "type = reduce_mean";
+static IniSection sgd_base("optimizer", "Type = sgd");
+static IniSection constant_loss("loss", "type = constant_derivative");
+
+IniWrapper reduce_mean_last("reduce_mean_last",
+                            {
+                              nn_base + "batch_size=3",
+                              sgd_base + "learning_rate=0.1",
+                              IniSection("fc_1") + fc_base +
+                                "unit=7 | input_shape=1:1:2",
+                              IniSection("red_mean") + red_mean_base + "axis=3",
+                              constant_loss,
+                            });
+
+INSTANTIATE_TEST_CASE_P(
+  model, nntrainerModelTest,
+  ::testing::ValuesIn({
+    mkModelIniTc(reduce_mean_last, DIM_UNUSED, NOT_USED_,
+                 ModelTestOption::COMPARE_V2),
+  }),
+  [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info) {
+    return std::get<1>(info.param);
+  });