Move weight_decay handling from opt to layer
authorJihoon Lee <jhoon.it.lee@samsung.com>
Wed, 24 Jun 2020 02:08:28 +0000 (11:08 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 24 Jun 2020 04:50:45 +0000 (13:50 +0900)
**Changes proposed in this PR:**
- remove weight_decay from `Optimizer::calculate` signature
- apply weight decay to fc_layer.cpp
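
In short, the call contract changes roughly as follows (a minimal sketch pieced together from the hunks below; `djdb` stands for whatever tensor the layer passes as the bias gradient):

```cpp
// Before: the optimizer received the decay parameters and folded them into djdw itself.
opt.calculate(djdw, djdb, weight, bias, iteration, init_zero, weight_decay);

// After: the layer folds the L2 term into its own gradient first ...
if (isWeightDecayL2Norm())
  djdw.add_i(weight, weight_decay.lambda); // dJ/dW += lambda * W
// ... and the optimizer only averages the gradients and applies the update.
opt.calculate(djdw, djdb, weight, bias, iteration, init_zero);
```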

Please note that conv2d_layer::backwarding also needs to handle weight
decay after this PR is merged.
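
A possible shape for that follow-up, reusing the new `isWeightDecayL2Norm()` helper (a sketch only; `djdk` and `kernel` are placeholder names for the conv filter gradient and filter weights):

```cpp
// Hypothetical sketch for conv2d_layer::backwarding (not part of this PR):
// fold the L2 penalty into the filter gradient before the optimizer update,
// mirroring the fc_layer.cpp change below.
Tensor djdk_decayed = djdk.chain()
                        .applyIf(isWeightDecayL2Norm(), _LIFT(add_i), kernel,
                                 weight_decay.lambda)
                        .run();
```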

Resolves #213

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/include/layer.h
nntrainer/include/optimizer.h
nntrainer/src/fc_layer.cpp
nntrainer/src/optimizer.cpp

index 38d176b..e29b53a 100644 (file)
@@ -330,11 +330,20 @@ public:
   };
 
 protected:
+
+  /**
+   * @brief     check if the current layer's weight decay type is l2norm
+   * @return    bool true if the weight decay type is L2 norm
+   */
+  bool isWeightDecayL2Norm() {
+    return weight_decay.type == WeightDecayType::l2norm;
+  }
   /**
    * @brief     Input Tensor
    */
   Tensor input;
 
+
   /**
    * @brief     Hidden Layer Tensor which store the
    *            forwading result
index a71f440..76fb1f0 100644 (file)
@@ -157,14 +157,13 @@ public:
    * @brief     calculate optimizer and Update Weight & Bais
    * @param[in] dJdW Weight derivative
    * @param[in] dJdB Bias derivative
-   * @param[in] Weight Weight Tensor
-   * @param[in] Bias Bais Tensor
+   * @param[in/out] Weight Weight Tensor
+   * @param[in/out] Bias Bias Tensor
    * @param[in] iteration nth epoch number
    * @param[in] init_zero bool it is true if bias sets zero.
-   * @param[in] weight_decay weight decay type & lambda
    */
-  void calculate(Tensor &djdw, Tensor &djdb, Tensor &weight, Tensor &bias,
-                 int iteration, bool init_zero, WeightDecayParam weight_decay);
+  void calculate(const Tensor &djdw, const Tensor &djdb, Tensor &weight,
+                 Tensor &bias, int iteration, bool init_zero);
 
   /**
    * @brief     Property Enumeration
index 10cc8be..6680c4c 100644 (file)
@@ -161,10 +161,11 @@ void FullyConnectedLayer::copy(std::shared_ptr<Layer> l) {
 
 Tensor FullyConnectedLayer::backwarding(Tensor derivative, int iteration) {
   Tensor ret = derivative.dot(weight.transpose("0:2:1"));
-  Tensor djdw = input.transpose("0:2:1").dot(derivative);
-
-  opt.calculate(djdw, derivative, weight, bias, iteration, this->init_zero,
-                weight_decay);
+  Tensor djdw = input.chain()
+                  .transpose("0:2:1")
+                  .dot(derivative)
+                  .applyIf(this->isWeightDecayL2Norm(), _LIFT(add_i), weight, weight_decay.lambda)
+                  .run();
 
   return ret;
 }
index 01bc8ef..dbb355f 100644 (file)
@@ -77,26 +77,16 @@ int Optimizer::initialize(TensorDim d, bool set_tensor) {
   return status;
 }
 
-void Optimizer::calculate(Tensor &djdw, Tensor &djdb, Tensor &weight,
-                          Tensor &bias, int iteration, bool init_zero,
-                          WeightDecayParam weight_decay) {
+void Optimizer::calculate(const Tensor &djdw, const Tensor &djdb,
+                          Tensor &weight, Tensor &bias, int iteration,
+                          bool init_zero) {
   Tensor djdwAvg, djdbAvg;
-  if (weight_decay.type == WeightDecayType::l2norm) {
-    djdw = djdw.add(weight.multiply(weight_decay.lambda));
-  }
-
   float ll = popt.learning_rate;
   if (popt.decay_steps != -1) {
     ll = ll * pow(popt.decay_rate, (iteration / popt.decay_steps));
   }
 
-  bool isL2norm = weight_decay.type == WeightDecayType::l2norm;
-
-  djdwAvg = djdw.chain()
-              .applyIf(isL2norm, _LIFT(add_i), weight, weight_decay.lambda)
-              .average()
-              .run();
-
+  djdwAvg = djdw.average();
   djdbAvg = djdb.average();
 
   switch (type) {