[layers] Add trainable feature
author Parichay Kapoor <pk.kapoor@samsung.com>
Fri, 3 Jul 2020 07:08:19 +0000 (16:08 +0900)
committer Parichay Kapoor <pk.kapoor@samsung.com>
Mon, 6 Jul 2020 04:14:37 +0000 (13:14 +0900)
Add a trainable property to each layer, allowing individual layers to be frozen (excluded from weight updates) without affecting the training of the other layers.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
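For illustration, the flag is toggled through the new public Layer::setTrainable() API added below. A minimal usage sketch (the freezeLayer helper is hypothetical, not part of this commit): a frozen layer still propagates the derivative to the previous layer in backwarding(), but opt.apply_gradients() is skipped, so its weights stay fixed while the rest of the network keeps training.

    #include "layer.h"

    /* Hypothetical helper, assuming the Layer class from this patch:
     * freezing skips the weight update but not backpropagation itself. */
    void freezeLayer(Layer &layer) { layer.setTrainable(false); }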
nntrainer/include/layer.h
nntrainer/src/conv2d_layer.cpp
nntrainer/src/fc_layer.cpp

diff --git a/nntrainer/include/layer.h b/nntrainer/include/layer.h
index dc9d876..a0b038f 100644
@@ -127,7 +127,8 @@ public:
       bn_follow(false),
       weight_decay(),
       weight_ini_type(WEIGHT_UNKNOWN),
-      flatten(false) {}
+      flatten(false),
+      trainable(true) {}
 
   /**
    * @brief     Destructor of Layer Class
@@ -297,6 +298,11 @@ public:
    */
   int setCost(CostType c);
 
+  /**
+   * @brief     set trainable for this layer
+   * @param[in] train true to enable training of this layer, false to disable
+   */
+  void setTrainable(bool train) { trainable = train; }
 
   /**
    * @brief     get gradients
@@ -434,6 +440,11 @@ protected:
    */
   bool flatten;
 
+  /**
+   * @brief     making this false will skip updating this layer's variables
+   */
+  bool trainable;
+
   /**
    * @brief     Gradient for the weights in this layer
    * @note      The order of gradients should match the order in weights
diff --git a/nntrainer/src/conv2d_layer.cpp b/nntrainer/src/conv2d_layer.cpp
index a684248..8b59a35 100644
@@ -191,18 +191,20 @@ Tensor Conv2DLayer::backwarding(Tensor derivative, int iteration) {
 
   gradients.clear();
 
-  //  Update K / bias
-  for (unsigned int i = 0; i < filter_size; ++i) {
-    Tensor djdw = delK[i]
-                    .chain()
-                    .applyIf(this->isWeightDecayL2Norm(), _LIFT(add_i),
-                             filters[i], weight_decay.lambda)
-                    .run();
-
-    gradients.push_back(djdw);
-    gradients.push_back(delBias[i]);
+  if (trainable) {
+    //  Update K / bias
+    for (unsigned int i = 0; i < filter_size; ++i) {
+      Tensor djdw = delK[i]
+                      .chain()
+                      .applyIf(this->isWeightDecayL2Norm(), _LIFT(add_i),
+                               filters[i], weight_decay.lambda)
+                      .run();
+
+      gradients.push_back(djdw);
+      gradients.push_back(delBias[i]);
+    }
+    opt.apply_gradients(weights, gradients, iteration);
   }
-  opt.apply_gradients(weights, gradients, iteration);
 
   return strip_pad(ret, padding);
 }
diff --git a/nntrainer/src/fc_layer.cpp b/nntrainer/src/fc_layer.cpp
index 305d774..84b736d 100644
@@ -178,11 +178,14 @@ Tensor FullyConnectedLayer::backwarding(Tensor derivative, int iteration) {
                            weight_decay.lambda)
                   .run();
 
   gradients.clear();
-  gradients.push_back(djdw);
-  gradients.push_back(djdb);
 
-  opt.apply_gradients(weights, gradients, iteration);
+  if (trainable) {
+    gradients.push_back(djdw);
+    gradients.push_back(djdb);
+
+    opt.apply_gradients(weights, gradients, iteration);
+  }
 
   return ret;
 }
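The guarded update, reduced to a self-contained toy example (ToyLayer and everything below are illustrative only, not nntrainer code): the derivative is always returned so backpropagation through the layer stays intact, and only the weight update is gated on trainable.

    #include <iostream>

    struct ToyLayer {
      bool trainable = true;
      float weight = 2.0f;
      float last_in = 0.0f;

      float forwarding(float x) {
        last_in = x;
        return weight * x;
      }

      /* Returns dL/dx for the previous layer; the weight update is applied
       * only when trainable, mirroring the guarded apply_gradients(). */
      float backwarding(float dLdy, float lr) {
        float dLdx = dLdy * weight;      /* propagated regardless of trainable */
        if (trainable)
          weight -= lr * dLdy * last_in; /* dL/dw update, skipped when frozen */
        return dLdx;
      }
    };

    int main() {
      ToyLayer frozen;
      frozen.trainable = false;
      frozen.forwarding(1.0f);
      frozen.backwarding(0.5f, 0.1f);
      std::cout << frozen.weight << '\n'; /* prints 2: no update was applied */
      return 0;
    }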