add Weight initialization Method
authorjijoong.moon <jijoong.moon@samsung.com>
Mon, 16 Mar 2020 03:44:59 +0000 (12:44 +0900)
committer문지중/On-Device Lab(SR)/Principal Engineer/삼성전자 <jijoong.moon@samsung.com>
Mon, 16 Mar 2020 03:58:34 +0000 (12:58 +0900)
Add Weight Initialization Method

    "lecun_normal"  : LeCun Normal Initialization
    "lecun_uniform"  : LeCun Uniform Initialization
    "xavier_normal"  : Xavier Normal Initialization
    "xavier_uniform"  : Xavier Uniform Initialization
    "he_normal"  : He Normal Initialization
    "he_uniform"  : He Uniform Initialization

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
include/layers.h
include/neuralnet.h
src/layers.cpp
src/neuralnet.cpp

index 179dbf5..3c5c0de 100644 (file)
@@ -69,6 +69,25 @@ typedef enum { ACT_TANH, ACT_SIGMOID, ACT_RELU, ACT_UNKNOWN } acti_type;
 typedef enum { LAYER_IN, LAYER_FC, LAYER_OUT, LAYER_BN, LAYER_UNKNOWN } layer_type;
 
 /**
+ * @brief     Enumeration of Weight Initialization Type
+ *            0. WEIGHT_LECUN_NORMAL ( LeCun normal initialization )
+ *            1. WEIGHT_LECUN_UNIFORM (LeCun uniform initialization )
+ *            2. WEIGHT_XAVIER_NORMAL ( Xavier normal initialization )
+ *            3. WEIGHT_XAVIER_UNIFORM ( Xavier uniform initialization )
+ *            4. WEIGHT_HE_NORMAL ( He normal initialization )
+ *            5. WEIGHT_HE_UNIFORM ( He uniform initialization )
+ */
+typedef enum {
+  WEIGHT_LECUN_NORMAL,
+  WEIGHT_LECUN_UNIFORM,
+  WEIGHT_XAVIER_NORMAL,
+  WEIGHT_XAVIER_UNIFORM,
+  WEIGHT_HE_NORMAL,
+  WEIGHT_HE_UNIFORM,
+  WEIGHT_UNKNOWN
+} weightIni_type;
+
+/**
  * @brief     type for the Optimizor to save hyper-parameter
  */
 typedef struct {
@@ -123,8 +142,9 @@ class Layer {
    * @param[in] w Width
    * @param[in] id index of this layer
    * @param[in] init_zero Bias initialization with zero
+   * @param[in] wini Weight Initialization Scheme
    */
-  virtual void initialize(int b, int h, int w, int id, bool init_zero) = 0;
+  virtual void initialize(int b, int h, int w, int id, bool init_zero, weightIni_type wini) = 0;
 
   /**
    * @brief     read layer Weight & Bias data from file
@@ -282,8 +302,9 @@ class InputLayer : public Layer {
    * @param[in] w width
    * @param[in] id index of this layer
    * @param[in] init_zero boolean to set Bias zero
+   * @param[in] wini Weight Initialization Scheme
    */
-  void initialize(int b, int h, int w, int id, bool init_zero);
+  void initialize(int b, int h, int w, int id, bool init_zero, weightIni_type wini);
 
   /**
    * @brief     Copy Layer
@@ -380,8 +401,9 @@ class FullyConnectedLayer : public Layer {
    * @param[in] w width
    * @param[in] id layer index
    * @param[in] init_zero boolean to set Bias zero
+   * @param[in] wini Weight Initialization Scheme
    */
-  void initialize(int b, int h, int w, int id, bool init_zero);
+  void initialize(int b, int h, int w, int id, bool init_zero, weightIni_type wini);
 
  private:
   Tensor Weight;
@@ -467,8 +489,9 @@ class OutputLayer : public Layer {
    * @param[in] w width
    * @param[in] id layer index
    * @param[in] init_zero boolean to set Bias zero
+   * @param[in] wini Weight Initialization Scheme
    */
-  void initialize(int b, int w, int h, int id, bool init_zero);
+  void initialize(int b, int w, int h, int id, bool init_zero, weightIni_type wini);
 
   /**
    * @brief     get Loss value
@@ -577,8 +600,9 @@ class BatchNormalizationLayer : public Layer {
    * @param[in] w width
    * @param[in] id layer index
    * @param[in] init_zero boolean to set Bias zero
+   * @param[in] wini Weight Initialization Scheme
    */
-  void initialize(int b, int h, int w, int id, bool init_zero);
+  void initialize(int b, int h, int w, int id, bool init_zero, weightIni_type wini);
 
  private:
   Tensor Weight;
index 90a028f..fa226dc 100644 (file)
@@ -51,9 +51,11 @@ typedef enum { NET_KNN, NET_REG, NET_NEU, NET_UNKNOWN } net_type;
  *            2. NET     ( Network Token )
  *            3. ACTI    ( Activation Token )
  *            4. LAYER   ( Layer Token )
- *            5. UNKNOWN
+ *            5. WEIGHTINI  ( Weight Initialization Token )
+ *            6. UNKNOWN
  */
-typedef enum { TOKEN_OPT, TOKEN_COST, TOKEN_NET, TOKEN_ACTI, TOKEN_LAYER, TOKEN_UNKNOWN } input_type;
+typedef enum { TOKEN_OPT, TOKEN_COST, TOKEN_NET, TOKEN_ACTI, TOKEN_LAYER, TOKEN_WEIGHTINI, TOKEN_UNKNOWN } input_type;
+
 
 /**
  * @class   NeuralNetwork Class
@@ -214,6 +216,11 @@ class NeuralNetwork {
   Layers::cost_type cost;
 
   /**
+   * @brief     Weight Initialization type
+   */
+  Layers::weightIni_type weightini;
+
+  /**
    * @brief     Model path to save or read
    */
   std::string model;
index 5010841..3a4e148 100644 (file)
@@ -23,6 +23,7 @@
 
 #include "include/layers.h"
 #include <assert.h>
+#include <random>
 
 /**
  * @brief     random function
@@ -93,6 +94,63 @@ float ReluPrime(float x) {
   }
 }
 
+static void WeightInitialization(Tensor W, unsigned int width, unsigned int height, Layers::weightIni_type init_type) {
+  std::random_device rd;
+  std::mt19937 gen(rd());
+
+  switch (init_type) {
+    case Layers::WEIGHT_LECUN_NORMAL: {
+      std::normal_distribution<float> dist(0, sqrt(1 / height));
+      for (unsigned int i = 0; i < width; ++i)
+        for (unsigned int j = 0; j < height; ++j) {
+          float f = dist(gen);
+          W.setValue(0, j, i, f);
+        }
+    } break;
+    case Layers::WEIGHT_LECUN_UNIFORM: {
+      std::uniform_real_distribution<float> dist(-1.0 * sqrt(1.0 / height), sqrt(1.0 / height));
+      for (unsigned int i = 0; i < width; ++i)
+        for (unsigned int j = 0; j < height; ++j) {
+          float f = dist(gen);
+          W.setValue(0, j, i, f);
+        }
+    } break;
+    case Layers::WEIGHT_XAVIER_NORMAL: {
+      std::normal_distribution<float> dist(0, sqrt(2.0 / (width + height)));
+      for (unsigned int i = 0; i < width; ++i)
+        for (unsigned int j = 0; j < height; ++j) {
+          float f = dist(gen);
+          W.setValue(0, j, i, f);
+        }
+    } break;
+    case Layers::WEIGHT_XAVIER_UNIFORM: {
+      std::uniform_real_distribution<float> dist(-1.0 * sqrt(6.0 / (height + width)), sqrt(6.0 / (height + width)));
+      for (unsigned int i = 0; i < width; ++i)
+        for (unsigned int j = 0; j < height; ++j) {
+          float f = dist(gen);
+          W.setValue(0, j, i, f);
+        }
+    } break;
+    case Layers::WEIGHT_HE_NORMAL: {
+      std::normal_distribution<float> dist(0, sqrt(2.0 / (height)));
+      for (unsigned int i = 0; i < width; ++i)
+        for (unsigned int j = 0; j < height; ++j) {
+          float f = dist(gen);
+          W.setValue(0, j, i, f);
+        }
+    } break;
+    case Layers::WEIGHT_HE_UNIFORM: {
+      std::uniform_real_distribution<float> dist(-1.0 * sqrt(6.0 / (height)), sqrt(6.0 / (height)));
+      for (unsigned int i = 0; i < width; ++i)
+        for (unsigned int j = 0; j < height; ++j) {
+          float f = dist(gen);
+          W.setValue(0, j, i, f);
+        }
+    } break;
+    default:
+      break;
+  }
+}
 namespace Layers {
 
 void InputLayer::setOptimizer(Optimizer opt) {
@@ -132,7 +190,7 @@ Tensor InputLayer::forwarding(Tensor input) {
   return Input;
 }
 
-void InputLayer::initialize(int b, int h, int w, int id, bool init_zero) {
+void InputLayer::initialize(int b, int h, int w, int id, bool init_zero, weightIni_type wini) {
   this->batch = b;
   this->width = w;
   this->height = h;
@@ -140,7 +198,7 @@ void InputLayer::initialize(int b, int h, int w, int id, bool init_zero) {
   this->bnfallow = false;
 }
 
-void FullyConnectedLayer::initialize(int b, int h, int w, int id, bool init_zero) {
+void FullyConnectedLayer::initialize(int b, int h, int w, int id, bool init_zero, weightIni_type wini) {
   this->batch = b;
   this->width = w;
   this->height = h;
@@ -151,7 +209,8 @@ void FullyConnectedLayer::initialize(int b, int h, int w, int id, bool init_zero
   Weight = Tensor(h, w);
   Bias = Tensor(1, w);
 
-  Weight = Weight.applyFunction(random);
+  WeightInitialization(Weight, w, h, wini);
+
   if (init_zero) {
     Bias.setZero();
   } else {
@@ -259,7 +318,7 @@ Tensor FullyConnectedLayer::backwarding(Tensor derivative, int iteration) {
   return ret;
 }
 
-void OutputLayer::initialize(int b, int h, int w, int id, bool init_zero) {
+void OutputLayer::initialize(int b, int h, int w, int id, bool init_zero, weightIni_type wini) {
   this->batch = b;
   this->width = w;
   this->height = h;
@@ -270,7 +329,9 @@ void OutputLayer::initialize(int b, int h, int w, int id, bool init_zero) {
   this->cost = cost;
   this->bnfallow = false;
 
-  Weight = Weight.applyFunction(random);
+  // Weight = Weight.applyFunction(random);
+  WeightInitialization(Weight, w, h, wini);
+
   if (init_zero) {
     Bias.setZero();
   } else {
@@ -477,7 +538,7 @@ Tensor OutputLayer::backwarding(Tensor label, int iteration) {
   return ret;
 }
 
-void BatchNormalizationLayer::initialize(int b, int h, int w, int id, bool init_zero) {
+void BatchNormalizationLayer::initialize(int b, int h, int w, int id, bool init_zero, weightIni_type wini) {
   this->batch = b;
   this->width = w;
   this->height = h;
index 3d9de1f..9e21407 100644 (file)
@@ -93,6 +93,17 @@ std::vector<std::string> activation_string = {"tanh", "sigmoid", "relu"};
 std::vector<std::string> layer_string = {"InputLayer", "FullyConnectedLayer", "OutputLayer", "BatchNormalizationLayer"};
 
 /**
+ * @brief     Weight Initialization Type String from configure file
+ *            "lecun_normal"  : LeCun Normal Initialization
+ *            "lecun_uniform"  : LeCun Uniform Initialization
+ *            "xavier_normal"  : Xavier Normal Initialization
+ *            "xavier_uniform"  : Xavier Uniform Initialization
+ *            "he_normal"  : He Normal Initialization
+ *            "he_uniform"  : He Uniform Initialization
+ */
+  std::vector<std::string> weightini_string = {"lecun_normal", "lecun_uniform", "xavier_normal", "xavier_uniform", "he_normal", "he_uniform"};
+
+/**
  * @brief     Check Existance of File
  * @param[in] filename file path to check
  * @retval    boolean true if exists
@@ -171,6 +182,14 @@ unsigned int parseType(std::string ll, input_type t) {
       }
       ret = i - 1;
       break;
+    case TOKEN_WEIGHTINI:
+      for (i = 0; i < weightini_string.size(); i++) {
+        if (caseInSensitiveCompare(weightini_string[i], ll)) {
+          return (i);
+        }
+      }
+      ret = i - 1;
+      break;
     case TOKEN_UNKNOWN:
     default:
       ret = 3;
@@ -209,6 +228,7 @@ void NeuralNetwork::init() {
   opt.type = (Layers::opt_type)parseType(iniparser_getstring(ini, "Network:Optimizer", NULL), TOKEN_OPT);
   opt.activation = (Layers::acti_type)parseType(iniparser_getstring(ini, "Network:Activation", NULL), TOKEN_ACTI);
   cost = (Layers::cost_type)parseType(iniparser_getstring(ini, "Network:Cost", NULL), TOKEN_COST);
+  weightini = (Layers::weightIni_type)parseType(iniparser_getstring(ini, "Network:WeightIni", "xavier_normal"), TOKEN_WEIGHTINI);
 
   model = iniparser_getstring(ini, "Network:Model", "model.bin");
   batchsize = iniparser_getint(ini, "Network:minibatch", 1);
@@ -256,7 +276,7 @@ void NeuralNetwork::init() {
       case Layers::LAYER_IN: {
         Layers::InputLayer *inputlayer = new (Layers::InputLayer);
         inputlayer->setType(t);
-        inputlayer->initialize(batchsize, 1, HiddenSize[i], id, b_zero);
+        inputlayer->initialize(batchsize, 1, HiddenSize[i], id, b_zero, weightini);
         inputlayer->setOptimizer(opt);
         inputlayer->setNormalization(iniparser_getboolean(ini, (layers_name[i] + ":Normalization").c_str(), false));
         inputlayer->setStandardization(iniparser_getboolean(ini, (layers_name[i] + ":Standardization").c_str(), false));
@@ -265,14 +285,14 @@ void NeuralNetwork::init() {
       case Layers::LAYER_FC: {
         Layers::FullyConnectedLayer *fclayer = new (Layers::FullyConnectedLayer);
         fclayer->setType(t);
-        fclayer->initialize(batchsize, HiddenSize[i - 1], HiddenSize[i], id, b_zero);
+        fclayer->initialize(batchsize, HiddenSize[i - 1], HiddenSize[i], id, b_zero, weightini);
         fclayer->setOptimizer(opt);
         layers.push_back(fclayer);
       } break;
       case Layers::LAYER_OUT: {
         Layers::OutputLayer *outputlayer = new (Layers::OutputLayer);
         outputlayer->setType(t);
-        outputlayer->initialize(batchsize, HiddenSize[i - 1], HiddenSize[i], id, b_zero);
+        outputlayer->initialize(batchsize, HiddenSize[i - 1], HiddenSize[i], id, b_zero, weightini);
         outputlayer->setOptimizer(opt);
         outputlayer->setCost(cost);
         outputlayer->setSoftmax(iniparser_getboolean(ini, (layers_name[i] + ":Softmax").c_str(), false));
@@ -282,7 +302,7 @@ void NeuralNetwork::init() {
         Layers::BatchNormalizationLayer *bnlayer = new (Layers::BatchNormalizationLayer);
         bnlayer->setType(t);
         bnlayer->setOptimizer(opt);
-        bnlayer->initialize(batchsize, 1, HiddenSize[i], id, b_zero);
+        bnlayer->initialize(batchsize, 1, HiddenSize[i], id, b_zero, weightini);
         layers.push_back(bnlayer);
         layers[i - 1]->setBNfallow(true);
       } break;