[weight-decay] Bug fix for weight decay with Adam
authorParichay Kapoor <pk.kapoor@samsung.com>
Fri, 18 Feb 2022 11:01:58 +0000 (20:01 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 21 Feb 2022 13:44:22 +0000 (22:44 +0900)
Weight decay should be applied to the gradient before calling the optimizer.
This was not detected earlier because it was only tested with SGD.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
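
Why the order matters (illustration only, not part of the patch): SGD scales the raw gradient by the learning rate, so folding the decay term in right before the single update step still gives the intended result, which is why the bug stayed hidden. Adam rescales the gradient through its running moment estimates, so a decay term added after that transform bypasses the moment statistics entirely. Below is a minimal standalone C++ sketch of the difference, using toy constants and a single bias-corrected Adam step; none of this is nntrainer code.

    #include <cmath>
    #include <cstdio>

    int main() {
      const float lr = 0.01f, decay = 0.1f, beta1 = 0.9f, beta2 = 0.999f,
                  eps = 1e-8f;
      float w = 1.0f, grad = 0.5f;

      // SGD: decay folded into the gradient before the plain scaled step.
      float sgd_w = w - lr * (grad + decay * w);

      // Adam, fixed ordering: decay is part of the gradient the moments see.
      float g = grad + decay * w;
      float m = (1 - beta1) * g;
      float v = (1 - beta2) * g * g;
      float m_hat = m / (1 - beta1);
      float v_hat = v / (1 - beta2);
      float adam_fixed = w - lr * m_hat / (std::sqrt(v_hat) + eps);

      // Adam, buggy ordering: moments are built from the raw gradient and the
      // decay term is only added to the already-transformed update, so it
      // never passes through the moment statistics.
      float m2_hat = (1 - beta1) * grad / (1 - beta1);
      float v2_hat = (1 - beta2) * grad * grad / (1 - beta2);
      float adam_buggy =
        w - lr * (m2_hat / (std::sqrt(v2_hat) + eps) + decay * w);

      std::printf("sgd %.6f adam(fixed) %.6f adam(buggy) %.6f\n", sgd_w,
                  adam_fixed, adam_buggy);
      return 0;
    }

Running the sketch shows the two Adam variants diverge after a single step; only the "fixed" path reflects the coupled decay behaviour that the reordered calls below restore.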
nntrainer/models/neuralnet.cpp
nntrainer/tensor/weight.h

index 85180c6..ac54e38 100644 (file)
@@ -299,6 +299,7 @@ void NeuralNetwork::backwarding(int iteration) {
       model_graph.applyGradients(
         node.get(), [iteration, opt_ = opt.get()](Weight &w) {
           w.calcRegularizationGradient();
+          w.calcWeightDecayGradient();
           RunOptimizerContext opt_context(&w, iteration);
           opt_->applyGradient(opt_context);
         });
index 2679611..988f7de 100644 (file)
@@ -234,7 +234,7 @@ public:
   }
 
   /**
-   * @brief     Calculate gradient from the regularizaiton of the weight
+   * @brief     Calculate gradient from the regularization of the weight
    */
   void calcRegularizationGradient() {
     if (isWeightRegularizerL2Norm())
@@ -242,16 +242,19 @@ public:
   }
 
   /**
-   * @brief     Apply the gradient to the weight
+   * @brief     Calculate gradient from the decay of the weight
    */
-  void applyGradient(double lr) {
+  void calcWeightDecayGradient() {
     if (isWeightDecay())
       applyWeightDecay();
-
-    var->add_i(*grad.get(), -lr);
   }
 
   /**
+   * @brief     Apply the gradient to the weight
+   */
+  void applyGradient(double lr) { var->add_i(*grad.get(), -lr); }
+
+  /**
    * @brief Check if the gradient is supposed to be clipped by global norm with
    * the given max_norm value
    *