[test] unittest for weight decay
authorParichay Kapoor <pk.kapoor@samsung.com>
Thu, 27 Jan 2022 06:36:46 +0000 (15:36 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 9 Feb 2022 09:34:12 +0000 (18:34 +0900)
added unittest for weight decay along with fix.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/layers/lstm.cpp
nntrainer/tensor/weight.h
packaging/unittest_models_v2.tar.gz
test/input_gen/genModelTests_v2.py
test/input_gen/recorder_v2.py
test/unittest/layers/unittest_layers_fully_connected.cpp
test/unittest/models/unittest_models.cpp

index 78bbcb3..e696e75 100644 (file)
@@ -200,9 +200,9 @@ void LSTMLayer::finalize(InitLayerContext &context) {
   // weight_ih ( input to hidden ) : [ 1, 1, feature_size, NUM_GATE * unit ]
   // -> i, f, g, o
   const TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
-  wt_idx[LSTMParams::weight_ih] =
-    context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, weight_decay, "weight_ih", true);
+  wt_idx[LSTMParams::weight_ih] = context.requestWeight(
+    weight_ih_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight_ih", true);
   // weight_hh ( hidden to hidden ) : [ 1, 1, unit, NUM_GATE * unit ] -> i,
   // f, g, o
   const TensorDim weight_hh_dim({unit, NUM_GATE * unit});
@@ -220,12 +220,12 @@ void LSTMLayer::finalize(InitLayerContext &context) {
     } else {
       // bias_ih ( input bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
       const TensorDim bias_ih_dim({NUM_GATE * unit});
-      wt_idx[LSTMParams::bias_ih] =
-        context.requestWeight(bias_ih_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, bias_decay, "bias_ih", true);
+      wt_idx[LSTMParams::bias_ih] = context.requestWeight(
+        bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_decay, "bias_ih", true);
       // bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
       wt_idx[LSTMParams::bias_hh] = context.requestWeight(
-        bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
         bias_decay, "bias_hh", true);
     }
   }
@@ -257,14 +257,14 @@ void LSTMLayer::finalize(InitLayerContext &context) {
     const TensorDim reverse_weight_ih_dim({feature_size, NUM_GATE * unit});
     wt_idx[LSTMParams::reverse_weight_ih] = context.requestWeight(
       reverse_weight_ih_dim, weight_initializer, weight_regularizer,
-      weight_regularizer_constant, "reverse_weight_ih", true);
+      weight_regularizer_constant, weight_decay, "reverse_weight_ih", true);
     // reverse_weight_hh ( hidden to hidden ) : [ 1, 1, unit, NUM_GATE *
     // unit ]
     // -> i, f, g, o
     const TensorDim reverse_weight_hh_dim({unit, NUM_GATE * unit});
     wt_idx[LSTMParams::reverse_weight_hh] = context.requestWeight(
       reverse_weight_hh_dim, weight_initializer, weight_regularizer,
-      weight_regularizer_constant, "reverse_weight_hh", true);
+      weight_regularizer_constant, weight_decay, "reverse_weight_hh", true);
     if (!disable_bias) {
       if (integrate_bias) {
         // reverse_bias_h ( input bias, hidden bias are integrate to 1 bias
@@ -272,20 +272,20 @@ void LSTMLayer::finalize(InitLayerContext &context) {
         const TensorDim reverse_bias_h_dim({NUM_GATE * unit});
         wt_idx[LSTMParams::reverse_bias_h] = context.requestWeight(
           reverse_bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
-          "reverse_bias_h", true);
+          bias_decay, "reverse_bias_h", true);
       } else {
         // reverse_bias_ih ( input bias ) : [ 1, 1, 1, NUM_GATE * unit ] ->
         // i, f, g, o
         const TensorDim reverse_bias_ih_dim({NUM_GATE * unit});
         wt_idx[LSTMParams::reverse_bias_ih] = context.requestWeight(
           reverse_bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
-          "reverse_bias_ih", true);
+          bias_decay, "reverse_bias_ih", true);
         // reverse_bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] ->
         // i, f, g, o
         const TensorDim reverse_bias_hh_dim({NUM_GATE * unit});
         wt_idx[LSTMParams::reverse_bias_hh] = context.requestWeight(
           reverse_bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
-          "reverse_bias_hh", true);
+          bias_decay, "reverse_bias_hh", true);
       }
     }
 
index a8b0525..2679611 100644 (file)
@@ -295,9 +295,9 @@ private:
   std::vector<Tensor *> opt_vars; /**< optimizer variables */
 
   /**
-   * @brief     Apply the gradient to the weight
+   * @brief     Apply the weight decay to the weight
    */
-  void applyWeightDecay() { var->add_i(*var.get(), -decay); }
+  void applyWeightDecay() { grad->add_i(*var.get(), decay); }
 };
 
 } // namespace nntrainer
index 4480c3f..40a7894 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
index 1f050ec..9ebd599 100644 (file)
@@ -72,6 +72,36 @@ class MolAttention(torch.nn.Module):
 
         return (output, kappa), loss
 
+class FCRelu(torch.nn.Module):
+    def __init__(self, decay=False):
+        super().__init__()
+        self.fc = torch.nn.Linear(3, 10)
+        self.fc1 = torch.nn.Linear(10, 2)
+        self.loss = torch.nn.MSELoss()
+        self.decay = decay
+
+    def forward(self, inputs, labels):
+        out = torch.relu(self.fc(inputs[0]))
+        out = torch.sigmoid(self.fc1(out))
+        loss = self.loss(out, labels[0])
+        return out, loss
+
+    def getOptimizer(self):
+        if not self.decay:
+            return torch.optim.SGD(self.parameters(), lr=0.1)
+        else:
+            decay_params = []
+            non_decay_params = []
+            for name, params in self.named_parameters():
+                if name == 'fc.weight' or name == 'fc1.bias':
+                    decay_params.append(params)
+                else:
+                    non_decay_params.append(params)
+            return torch.optim.SGD([
+                {'params': non_decay_params},
+                {'params': decay_params, 'weight_decay': 0.9}], lr=0.1)
+
+
 if __name__ == "__main__":
     record_v2(
         ReduceMeanLast(),
@@ -99,4 +129,15 @@ if __name__ == "__main__":
         name="mol_attention",
     )
 
-    # inspect_file("mol_attention_masked.nnmodelgolden")
+    fc_relu_decay = FCRelu(decay=True)
+    record_v2(
+        fc_relu_decay,
+        iteration=2,
+        input_dims=[(3,3)],
+        input_dtype=[float],
+        label_dims=[(3,2)],
+        name="fc_relu_decay",
+        optimizer=fc_relu_decay.getOptimizer()
+    )
+
+    inspect_file("fc_relu_decay.nnmodelgolden")
index 049a33d..5a58b9d 100644 (file)
@@ -62,7 +62,7 @@ def _rand_like(shapes, scale=1, dtype=None):
 # @param label_dims dimensions to record including batch (list of tuple)
 # @param name golden name
 def record_v2(model, iteration, input_dims, label_dims, name, clip=False,
-              input_dtype=None, input_label_reader=None):
+              input_dtype=None, input_label_reader=None, optimizer=None):
     ## file format is as below
     # [<number of iteration(int)> <Iteration> <Iteration>...<Iteration>]
     # Each iteration contains
@@ -74,7 +74,8 @@ def record_v2(model, iteration, input_dims, label_dims, name, clip=False,
     if os.path.isfile(file_name):
         print("Warning: the file %s is being truncated and overwritten" % file_name)
 
-    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+    if optimizer == None:
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
 
     def record_iteration(write_fn):
         if input_label_reader != None:
index afb9aa2..76ad965 100644 (file)
@@ -31,7 +31,11 @@ auto fc_basic_single_batch = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::FullyConnectedLayer>, {"unit=4"},
   "1:1:1:10", "fc_single_batch.nnlayergolden",
   LayerGoldenTestParamOptions::DEFAULT);
+auto fc_basic_no_decay = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::FullyConnectedLayer>,
+  {"unit=5", "weight_decay=0.0", "bias_decay=0.0"}, "3:1:1:10",
+  "fc_plain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
 
 INSTANTIATE_TEST_CASE_P(FullyConnected, LayerGoldenTest,
-                        ::testing::Values(fc_basic_plain,
-                                          fc_basic_single_batch));
+                        ::testing::Values(fc_basic_plain, fc_basic_single_batch,
+                                          fc_basic_no_decay));
index 5512208..257c884 100644 (file)
@@ -29,6 +29,7 @@ static std::string fc_base = "type = Fully_connected";
 static std::string red_mean_base = "type = reduce_mean";
 static IniSection sgd_base("optimizer", "Type = sgd");
 static IniSection constant_loss("loss", "type = constant_derivative");
+static IniSection act_base("activation", "Type = Activation");
 
 IniWrapper reduce_mean_last("reduce_mean_last",
                             {
@@ -40,6 +41,15 @@ IniWrapper reduce_mean_last("reduce_mean_last",
                               constant_loss,
                             });
 
+IniWrapper fc_relu_decay(
+  "fc_relu_decay",
+  {nn_base + "Loss=mse | batch_size = 3", sgd_base + "learning_rate = 0.1",
+   IniSection("input") + "type=input" + "input_shape = 1:1:3",
+   IniSection("dense") + fc_base + "unit = 10" + "weight_decay=0.9",
+   IniSection("act") + act_base + "Activation = relu",
+   IniSection("dense_1") + fc_base + "unit = 2" + "bias_decay=0.9",
+   IniSection("act_1") + act_base + "Activation = sigmoid"});
+
 static std::unique_ptr<NeuralNetwork> makeMolAttention() {
   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
   nn->setProperty({"batch_size=3"});
@@ -96,6 +106,8 @@ INSTANTIATE_TEST_CASE_P(
                  ModelTestOption::COMPARE_V2),
     mkModelTc_V2(makeMolAttentionMasked, "mol_attention_masked",
                  ModelTestOption::COMPARE_RUN_V2),
+    mkModelIniTc(fc_relu_decay, DIM_UNUSED, NOT_USED_,
+                 ModelTestOption::COMPARE_V2),
   }),
   [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info) {
     return std::get<1>(info.param);