[ Refactor ] Layer Method for weight initialization
authorjijoong.moon <jijoong.moon@samsung.com>
Tue, 2 Jun 2020 01:31:37 +0000 (10:31 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 5 Jun 2020 01:09:54 +0000 (10:09 +0900)
Cannot use weight initiatlization method for other layer.
In this PR, Move it into Layer class to use for other layer.

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
nntrainer/include/layer.h
nntrainer/src/fc_layer.cpp
nntrainer/src/layer.cpp
nntrainer/src/neuralnet.cpp

index 0fbca96..53e2be1 100644 (file)
@@ -194,18 +194,54 @@ public:
    */
   virtual void copy(std::shared_ptr<Layer> l) = 0;
 
-  void setBNfallow(bool ok) { this->bn_fallow = ok; }
+  /**
+   * @brief     set Batch Normalization Layer followed
+   * @param[in] ok true/false
+   */
+  void setBNfollow(bool ok) { this->bn_follow = ok; }
 
+  /**
+   * @brief     check hyper parameter for the layer
+   * @retval #ML_ERROR_NONE Successful.
+   * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
+   */
   int checkValidation();
 
+  /**
+   * @brief     set weight decay parameters
+   * @param[in] w struct for weight decay
+   */
   void setWeightDecay(WeightDecayParam w) { weight_decay = w; }
 
+  /**
+   * @brief  get Tensor Dimension
+   * @retval TensorDim Tensor Dimension
+   */
   TensorDim &getTensorDim() { return dim; }
 
+  /**
+   * @brief  set if this is last layer of Network
+   * @param[in] last true/false
+   */
   void setLast(bool last) { last_layer = last; }
 
+  /**
+   * @brief  set Weight Initialization Type
+   * @param[in] wini WeightIniType
+   */
   void setWeightInit(WeightIniType wini) { weight_ini_type = wini; }
 
+  /**
+   * @brief  initialize Weight
+   * @param[in] width width of Tensor
+   * @param[in] height height of Tensor
+   * @param[in] init_type Weight Initialization Type
+   * @param[out] status Status
+   * @retval Tensor Initialized Tensor
+   */
+  Tensor initializeWeight(unsigned int width, unsigned int height,
+                          WeightIniType init_type, int &status);
+
 protected:
   /**
    * @brief     Input Tensor
@@ -255,7 +291,7 @@ protected:
 
   ActiType activation_type;
 
-  bool bn_fallow;
+  bool bn_follow;
 
   WeightDecayParam weight_decay;
 
index b31c411..080d93d 100644 (file)
 #include <nntrainer_error.h>
 #include <nntrainer_log.h>
 #include <parse_util.h>
-#include <random>
 #include <util_func.h>
 
 namespace nntrainer {
 
-static auto rng = [] {
-  std::mt19937 rng;
-  rng.seed(std::random_device()());
-  return rng;
-}();
-
-template <typename... Args> static void RandNormal(Tensor &w, Args &&... args) {
-  std::normal_distribution<float> dist(std::forward<Args>(args)...);
-  unsigned int width = w.getWidth();
-  unsigned int height = w.getHeight();
-
-  for (unsigned int i = 0; i < width; ++i) {
-    for (unsigned int j = 0; j < height; ++j) {
-      w.setValue(0, j, i, dist(rng));
-    }
-  }
-}
-
-template <typename... Args>
-static void RandUniform(Tensor &w, Args &&... args) {
-  std::uniform_real_distribution<float> dist(std::forward<Args>(args)...);
-  unsigned int width = w.getWidth();
-  unsigned int height = w.getHeight();
-
-  for (unsigned int i = 0; i < width; ++i) {
-    for (unsigned int j = 0; j < height; ++j) {
-      w.setValue(0, j, i, dist(rng));
-    }
-  }
-}
-
-static Tensor weightInitialization(unsigned int width, unsigned int height,
-                                   WeightIniType init_type, int &status) {
-
-  Tensor w = Tensor(height, width);
-
-  if (init_type == WEIGHT_UNKNOWN) {
-    ml_logw("Warning: Weight Initalization Type is not set. "
-            "WEIGHT_XAVIER_NORMAL is used by default");
-    init_type = WEIGHT_XAVIER_NORMAL;
-  }
-
-  switch (init_type) {
-  case WEIGHT_LECUN_NORMAL:
-    RandNormal(w, 0, sqrt(1.0 / height));
-    break;
-  case WEIGHT_XAVIER_NORMAL:
-    RandNormal(w, 0, sqrt(2.0 / (width + height)));
-    break;
-  case WEIGHT_HE_NORMAL:
-    RandNormal(w, 0, sqrt(2.0 / (height)));
-    break;
-  case WEIGHT_LECUN_UNIFORM:
-    RandUniform(w, -1.0 * sqrt(1.0 / height), sqrt(1.0 / height));
-    break;
-  case WEIGHT_XAVIER_UNIFORM:
-    RandUniform(w, -1.0 * sqrt(6.0 / (height + width)),
-                sqrt(6.0 / (height + width)));
-    break;
-  case WEIGHT_HE_UNIFORM:
-    RandUniform(w, -1.0 * sqrt(6.0 / (height)), sqrt(6.0 / (height)));
-    break;
-  default:
-    break;
-  }
-  return w;
-}
-
 int FullyConnectedLayer::initialize(bool last) {
   int status = ML_ERROR_NONE;
   if (dim.batch() <= 0 || dim.height() <= 0 || dim.width() <= 0) {
@@ -109,8 +40,7 @@ int FullyConnectedLayer::initialize(bool last) {
   this->last_layer = last;
 
   bias = Tensor(1, dim.width());
-  weight =
-    weightInitialization(dim.width(), dim.height(), weight_ini_type, status);
+  weight = initializeWeight(dim.width(), dim.height(), weight_ini_type, status);
   NN_RETURN_STATUS();
 
   if (init_zero) {
@@ -200,7 +130,7 @@ Tensor FullyConnectedLayer::forwarding(Tensor in, int &status) {
   input = in;
   hidden = input.dot(weight).add(bias);
 
-  if (this->bn_fallow)
+  if (this->bn_follow)
     return hidden;
 
   if (activation_type == ACT_SOFTMAX) {
index 2304ce3..f3f6f90 100644 (file)
 #include <nntrainer_error.h>
 #include <nntrainer_log.h>
 #include <parse_util.h>
+#include <random>
 #include <util_func.h>
 
 namespace nntrainer {
 
+static auto rng = [] {
+  std::mt19937 rng;
+  rng.seed(std::random_device()());
+  return rng;
+}();
+
+template <typename... Args> static void RandNormal(Tensor &w, Args &&... args) {
+  std::normal_distribution<float> dist(std::forward<Args>(args)...);
+  unsigned int width = w.getWidth();
+  unsigned int height = w.getHeight();
+
+  for (unsigned int i = 0; i < width; ++i) {
+    for (unsigned int j = 0; j < height; ++j) {
+      w.setValue(0, j, i, dist(rng));
+    }
+  }
+}
+
+template <typename... Args>
+static void RandUniform(Tensor &w, Args &&... args) {
+  std::uniform_real_distribution<float> dist(std::forward<Args>(args)...);
+  unsigned int width = w.getWidth();
+  unsigned int height = w.getHeight();
+
+  for (unsigned int i = 0; i < width; ++i) {
+    for (unsigned int j = 0; j < height; ++j) {
+      w.setValue(0, j, i, dist(rng));
+    }
+  }
+}
+
 Layer::Layer() {
   type = LAYER_UNKNOWN;
   activation_type = ACT_UNKNOWN;
@@ -36,7 +68,7 @@ Layer::Layer() {
   init_zero = false;
   activation = NULL;
   activation_prime = NULL;
-  bn_fallow = false;
+  bn_follow = false;
   weight_decay.type = WeightDecayType::unknown;
   weight_decay.lambda = 0.0;
   weight_ini_type = WEIGHT_UNKNOWN;
@@ -92,4 +124,42 @@ int Layer::checkValidation() {
   }
   return status;
 }
+
+Tensor Layer::initializeWeight(unsigned int width, unsigned int height,
+                               WeightIniType init_type, int &status) {
+
+  Tensor w = Tensor(height, width);
+
+  if (init_type == WEIGHT_UNKNOWN) {
+    ml_logw("Warning: Weight Initalization Type is not set. "
+            "WEIGHT_XAVIER_NORMAL is used by default");
+    init_type = WEIGHT_XAVIER_NORMAL;
+  }
+
+  switch (init_type) {
+  case WEIGHT_LECUN_NORMAL:
+    RandNormal(w, 0, sqrt(1.0 / height));
+    break;
+  case WEIGHT_XAVIER_NORMAL:
+    RandNormal(w, 0, sqrt(2.0 / (width + height)));
+    break;
+  case WEIGHT_HE_NORMAL:
+    RandNormal(w, 0, sqrt(2.0 / (height)));
+    break;
+  case WEIGHT_LECUN_UNIFORM:
+    RandUniform(w, -1.0 * sqrt(1.0 / height), sqrt(1.0 / height));
+    break;
+  case WEIGHT_XAVIER_UNIFORM:
+    RandUniform(w, -1.0 * sqrt(6.0 / (height + width)),
+                sqrt(6.0 / (height + width)));
+    break;
+  case WEIGHT_HE_UNIFORM:
+    RandUniform(w, -1.0 * sqrt(6.0 / (height)), sqrt(6.0 / (height)));
+    break;
+  default:
+    break;
+  }
+  return w;
+}
+
 } /* namespace nntrainer */
index bd14a6c..f61aa50 100644 (file)
@@ -346,7 +346,7 @@ int NeuralNetwork::init() {
         ml_loge("Error: BN layer shouldn't be first layer of network");
         return ML_ERROR_INVALID_PARAMETER;
       }
-      layers[i - 1]->setBNfallow(true);
+      layers[i - 1]->setBNfollow(true);
       status = bn_layer->setActivation((ActiType)parseType(
         iniparser_getstring(ini, (layers_name[i] + ":Activation").c_str(),
                             unknown),
@@ -484,7 +484,7 @@ int NeuralNetwork::init(std::shared_ptr<Optimizer> optimizer,
       NN_RETURN_STATUS();
       status = layers[i]->initialize(last);
       NN_RETURN_STATUS();
-      layers[i - 1]->setBNfallow(true);
+      layers[i - 1]->setBNfollow(true);
       break;
     default:
       break;