Calculation of Validation Loss & Accuracy
authorjijoong.moon <jijoong.moon@samsung.com>
Thu, 13 Feb 2020 22:49:47 +0000 (07:49 +0900)
committer문지중/On-Device Lab(SR)/Principal Engineer/삼성전자 <jijoong.moon@samsung.com>
Thu, 13 Feb 2020 23:01:31 +0000 (08:01 +0900)
Calculate Validation Loss & Accuracy

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
Applications/Classification/jni/main.cpp
Applications/Classification/res/Classification.ini
include/layers.h
include/neuralnet.h
src/layers.cpp
src/neuralnet.cpp

index 3a2c8a0..6c97e66 100644 (file)
@@ -45,7 +45,9 @@
 /**
  * @brief     Data size for each category
  */
-#define TOTAL_DATA_SIZE 500
+#define TOTAL_TRAIN_DATA_SIZE 500
+
+#define TOTAL_VAL_DATA_SIZE 50
 
 /**
  * @brief     Number of category : Three
@@ -75,7 +77,8 @@ using namespace std;
  */
 string data_path;
 
-bool duplicate[TOTAL_LABEL_SIZE * TOTAL_DATA_SIZE];
+bool duplicate[TOTAL_LABEL_SIZE * TOTAL_TRAIN_DATA_SIZE];
+bool valduplicate[TOTAL_LABEL_SIZE * TOTAL_VAL_DATA_SIZE];
 
 /**
  * @brief     step function
@@ -209,10 +212,19 @@ void getFeature(const string filename, vector<float> &feature_input) {
  * @param[out] feature_input save output of tflite
  * @param[out] feature_output save label data
  */
-void ExtractFeatures(std::string p, vector<vector<float>> &feature_input, vector<vector<float>> &feature_output) {
+void ExtractFeatures(std::string p, vector<vector<float>> &feature_input, vector<vector<float>> &feature_output,
+                     std::string type) {
   string total_label[10] = {"airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"};
 
-  int trainingSize = TOTAL_LABEL_SIZE * TOTAL_DATA_SIZE;
+  int data_size = TOTAL_TRAIN_DATA_SIZE;
+  bool val = false;
+
+  if (!type.compare("val")) {
+    data_size = TOTAL_VAL_DATA_SIZE;
+    val = true;
+  }
+
+  int trainingSize = TOTAL_LABEL_SIZE * data_size;
 
   feature_input.resize(trainingSize);
   feature_output.resize(trainingSize);
@@ -223,10 +235,16 @@ void ExtractFeatures(std::string p, vector<vector<float>> &feature_input, vector
     std::string path = p;
     path += total_label[i];
 
-    for (int j = 0; j < TOTAL_DATA_SIZE; j++) {
+    for (int j = 0; j < data_size; j++) {
       std::string img = path + "/";
       std::stringstream ss;
-      ss << std::setw(4) << std::setfill('0') << (j + 1);
+
+      if (val) {
+        ss << std::setw(4) << std::setfill('0') << (5000 - j);
+      } else {
+        ss << std::setw(4) << std::setfill('0') << (j + 1);
+      }
+
       img += ss.str() + ".bmp";
       printf("%s\n", img.c_str());
 
@@ -246,8 +264,9 @@ bool getMiniBatch(std::vector<std::vector<float>> inVec, std::vector<std::vector
   std::vector<int> memI;
   std::vector<int> memJ;
   int count = 0;
+  int data_size = TOTAL_TRAIN_DATA_SIZE;
 
-  for (int i = 0; i < TOTAL_LABEL_SIZE * TOTAL_DATA_SIZE; i++) {
+  for (int i = 0; i < TOTAL_LABEL_SIZE * data_size; i++) {
     if (!duplicate[i])
       count++;
   }
@@ -257,7 +276,7 @@ bool getMiniBatch(std::vector<std::vector<float>> inVec, std::vector<std::vector
 
   count = 0;
   while (count < MINI_BATCH) {
-    int nomI = rangeRandom(0, TOTAL_LABEL_SIZE * TOTAL_DATA_SIZE - 1);
+    int nomI = rangeRandom(0, TOTAL_LABEL_SIZE * data_size - 1);
     if (!duplicate[nomI]) {
       memI.push_back(nomI);
       duplicate[nomI] = true;
@@ -277,10 +296,17 @@ bool getMiniBatch(std::vector<std::vector<float>> inVec, std::vector<std::vector
   return true;
 }
 
-void save(std::vector<std::vector<float>> inVec, std::vector<std::vector<float>> inLabel) {
-  std::string file = "trainingSet.dat";
+void save(std::vector<std::vector<float>> inVec, std::vector<std::vector<float>> inLabel, std::string type) {
+  std::string file = type + "Set.dat";
+  unsigned int data_size;
+  if (!type.compare("training")) {
+    data_size = TOTAL_TRAIN_DATA_SIZE;
+  } else if (!type.compare("val")) {
+    data_size = TOTAL_VAL_DATA_SIZE;
+  }
+
   std::ofstream TrainigSet(file, std::ios::out | std::ios::binary);
-  for (unsigned int i = 0; i < TOTAL_LABEL_SIZE * TOTAL_DATA_SIZE; i++) {
+  for (unsigned int i = 0; i < TOTAL_LABEL_SIZE * data_size; i++) {
     for (unsigned int j = 0; j < FEATURE_SIZE; j++) {
       TrainigSet.write((char *)&inVec[i][j], sizeof(float));
     }
@@ -289,16 +315,24 @@ void save(std::vector<std::vector<float>> inVec, std::vector<std::vector<float>>
   }
 }
 
-bool read(std::vector<std::vector<float>> &inVec, std::vector<std::vector<float>> &inLabel) {
-  std::string file = "trainingSet.dat";
+bool read(std::vector<std::vector<float>> &inVec, std::vector<std::vector<float>> &inLabel, std::string type) {
+  std::string file = type + "Set.dat";
+
+  unsigned int data_size;
+  if (!type.compare("training")) {
+    data_size = TOTAL_TRAIN_DATA_SIZE;
+  } else if (!type.compare("val")) {
+    data_size = TOTAL_VAL_DATA_SIZE;
+  }
+
   std::ifstream TrainigSet(file, std::ios::out | std::ios::binary);
   if (!TrainigSet.good())
     return false;
 
-  inVec.resize(TOTAL_LABEL_SIZE * TOTAL_DATA_SIZE);
-  inLabel.resize(TOTAL_LABEL_SIZE * TOTAL_DATA_SIZE);
+  inVec.resize(TOTAL_LABEL_SIZE * data_size);
+  inLabel.resize(TOTAL_LABEL_SIZE * data_size);
 
-  for (unsigned int i = 0; i < TOTAL_LABEL_SIZE * TOTAL_DATA_SIZE; i++) {
+  for (unsigned int i = 0; i < TOTAL_LABEL_SIZE * data_size; i++) {
     inVec[i].resize(FEATURE_SIZE);
     for (unsigned int j = 0; j < FEATURE_SIZE; j++) {
       TrainigSet.read((char *)&inVec[i][j], sizeof(float));
@@ -329,13 +363,22 @@ int main(int argc, char *argv[]) {
   srand(time(NULL));
   std::string ini_file = data_path + "ini.bin";
   std::vector<std::vector<float>> inputVector, outputVector;
+  std::vector<std::vector<float>> inputValVector, outputValVector;
 
-  if (!read(inputVector, outputVector)) {
+  if (!read(inputVector, outputVector, "training")) {
     /**
      * @brief     Extract Feature
      */
-    ExtractFeatures(data_path, inputVector, outputVector);
-    save(inputVector, outputVector);
+    ExtractFeatures(data_path, inputVector, outputVector, "training");
+    save(inputVector, outputVector, "training");
+  }
+
+  if (!read(inputValVector, outputValVector, "val")) {
+    /**
+     * @brief     Extract Feature
+     */
+    ExtractFeatures(data_path, inputValVector, outputValVector, "val");
+    save(inputValVector, outputValVector, "val");
   }
 
   /**
@@ -350,9 +393,10 @@ int main(int argc, char *argv[]) {
    * @brief     back propagation
    */
   if (training) {
+    float trainingloss = 0.0;
     for (int i = 0; i < ITERATION; i++) {
       int count = 0;
-      for (int j = 0; j < TOTAL_LABEL_SIZE * TOTAL_DATA_SIZE; j++) {
+      for (int j = 0; j < TOTAL_LABEL_SIZE * TOTAL_TRAIN_DATA_SIZE; j++) {
         duplicate[j] = false;
       }
 
@@ -367,10 +411,31 @@ int main(int argc, char *argv[]) {
         }
       }
 
+      trainingloss = NN.getLoss();
+
       Layers::Optimizer opt = NN.getOptimizer();
-      cout << "#" << i + 1 << "/" << ITERATION << " - Loss : " << NN.getLoss() << "( " << opt.decay_rate << " "
-           << opt.decay_steps << " : " << NN.getLearningRate() * pow(opt.decay_rate, (i / opt.decay_steps)) << " )"
-           << endl;
+
+      int right = 0;
+      float valloss = 0.0;
+
+      for (int j = 0; j < TOTAL_LABEL_SIZE; j++) {
+        for (int k = 0; k < TOTAL_VAL_DATA_SIZE; k++) {
+          Tensor X = Tensor({inputValVector[j * TOTAL_VAL_DATA_SIZE + k]});
+          Tensor Y2 = Tensor({outputValVector[j * TOTAL_VAL_DATA_SIZE + k]});
+          Tensor Y = NN.forwarding(X, Y2);
+          if (Y.argmax() == j)
+            right++;
+          valloss += NN.getLoss();
+        }
+      }
+
+      valloss = valloss / (float)(TOTAL_LABEL_SIZE * TOTAL_VAL_DATA_SIZE);
+
+      cout << "#" << i + 1 << "/" << ITERATION << " - Loss : " << trainingloss << " ( " << opt.decay_rate << " "
+           << opt.decay_steps << " : " << NN.getLearningRate() * pow(opt.decay_rate, (i / opt.decay_steps))
+           << " ) >> [ Accuracy : " << right / (float)(TOTAL_LABEL_SIZE * TOTAL_VAL_DATA_SIZE) * 100.0
+           << "% ] [ Validation Loss : " << valloss << " ] " << endl;
+
       NN.setLoss(0.0);
       if (training)
         NN.saveModel();
index 015c1c2..cb93043 100644 (file)
@@ -4,7 +4,7 @@ Type = NeuralNetwork    # Network Type : Regression, KNN, NeuralNetwork
 Layers = inputlayer \
          fc1layer \
         outputlayer    #Layers of Neuralnetwork
-Learning_rate = 0.0  # Learning Rate
+Learning_rate = 0.001  # Learning Rate
 Decay_rate = 0.96      # for the decay_rate for the decayed learning rate
 Decay_steps = 1000       # decay step for the exponential decayed learning rate
 Epoch = 30000          # Epoch 
@@ -31,12 +31,12 @@ Bias_zero = true    # Zero Bias
 Type = FullyConnectedLayer
 Id = 1
 Height = 62720         # Input Dimension ( = Weight Height )
-Width = 100            # Hidden Layer Dimension ( = Weight Width )
+Width = 128            # Hidden Layer Dimension ( = Weight Width )
 Bias_zero = true
 
 [outputlayer]
 Type = OutputLayer
 Id = 2
-Height = 100           # Hidden Layer Dimension ( = Weight Height )
+Height = 128           # Hidden Layer Dimension ( = Weight Height )
 Width = 10             # Output Layer Dimension ( = Weight Width )
 Bias_zero = true
index 12e1265..4520086 100644 (file)
@@ -100,6 +100,13 @@ class Layer {
   virtual Tensor forwarding(Tensor input) = 0;
 
   /**
+   * @brief     Forward Propation of neural Network
+   * @param[in] input Input Tensor taken by upper layer
+   * @retval    Output Tensor
+   */
+  virtual Tensor forwarding(Tensor input, Tensor output) = 0;
+
+  /**
    * @brief     Back Propation of neural Network
    * @param[in] input Input Tensor taken by lower layer
    * @param[in] iteration Epoch value for the ADAM Optimizer
@@ -249,6 +256,17 @@ class InputLayer : public Layer {
   Tensor forwarding(Tensor input);
 
   /**
+   * @brief     foward propagation : return Input Tensor
+   *            It return Input as it is.
+   * @param[in] input input Tensor from lower layer.
+   * @param[in] output label Tensor.
+   * @retval    return Input Tensor
+   */
+  Tensor forwarding(Tensor input, Tensor output){
+    return forwarding(input);
+  };
+
+  /**
    * @brief     Set Optimizer
    * @param[in] opt optimizer
    */
@@ -307,6 +325,17 @@ class FullyConnectedLayer : public Layer {
   Tensor forwarding(Tensor input);
 
   /**
+   * @brief     foward propagation : return Input Tensor
+   *            It return Input as it is.
+   * @param[in] input input Tensor from lower layer.
+   * @param[in] output label Tensor.
+   * @retval    Activation(W x input + B)
+   */
+  Tensor forwarding(Tensor input, Tensor output){
+    return forwarding (input);
+  };
+
+  /**
    * @brief     back propagation
    *            Calculate dJdB & dJdW & Update W & B
    * @param[in] input Input Tensor from lower layer
@@ -388,6 +417,14 @@ class OutputLayer : public Layer {
   Tensor forwarding(Tensor input);
 
   /**
+   * @brief     forward propagation with input and set loss
+   * @param[in] input Input Tensor from upper layer
+   * @param[in] output Label Tensor
+   * @retval    Activation(W x input + B)
+   */
+  Tensor forwarding(Tensor input, Tensor output);
+
+  /**
    * @brief     back propagation
    *            Calculate dJdB & dJdW & Update W & B
    * @param[in] input Input Tensor from lower layer
index 5450ab8..90a028f 100644 (file)
@@ -113,6 +113,14 @@ class NeuralNetwork {
   Tensor forwarding(Tensor input);
 
   /**
+   * @brief     forward propagation
+   * @param[in] input Input Tensor X
+   * @param[in] label Input Tensor Y2
+   * @retval    Output Tensor Y
+   */
+  Tensor forwarding(Tensor input, Tensor output);
+
+  /**
    * @brief     back propagation to update W & B
    * @param[in] input Input Tensor X
    * @param[in] expectedOutput Lable Tensor Y
index 38b151f..c4e94cc 100644 (file)
@@ -235,6 +235,50 @@ Tensor OutputLayer::forwarding(Tensor input) {
   return hidden;
 }
 
+Tensor OutputLayer::forwarding(Tensor input, Tensor output) {
+  Input = input;
+  hidden = input.dot(Weight).add(Bias).applyFunction(activation);
+  Tensor Y2 = output;
+  Tensor Y = hidden.softmax();
+  float lossSum = 0.0;
+
+  switch (cost) {
+    case COST_CATEGORICAL: {
+      Tensor temp = ((Y2.multiply(-1.0).transpose().dot(Y.add(opt.epsilon).applyFunction(log_float)))
+                         .subtract(Y2.multiply(-1.0).add(1.0).transpose().dot(
+                             Y.multiply(-1.0).add(1.0).add(opt.epsilon).applyFunction(log_float))));
+      loss = (1.0 / Y.Mat2Vec().size()) * temp.Mat2Vec()[0];
+    } break;
+    case COST_MSR: {
+      Tensor sub = Y2.subtract(Y);
+      Tensor l = (sub.multiply(sub)).sum().multiply(0.5);
+      std::vector<float> t = l.Mat2Vec();
+      for (int i = 0; i < l.getBatch(); i++) {
+        lossSum += t[i];
+      }
+
+      loss = lossSum / (float)l.getBatch();
+    } break;
+    case COST_ENTROPY: {
+      Tensor l = (Y2.multiply(Y.applyFunction(log_float))
+                      .add((Y2.multiply(-1.0).add(1.0)).multiply((Y.multiply(-1.0).add(1.0)).applyFunction(log_float))))
+                     .multiply(-1.0 / (Y2.getWidth()))
+                     .sum();
+
+      std::vector<float> t = l.Mat2Vec();
+
+      for (int i = 0; i < l.getBatch(); i++) {
+        lossSum += t[i];
+      }
+      loss = lossSum / (float)l.getBatch();
+    } break;
+    case COST_UNKNOWN:
+    default:
+      break;
+  }
+  return hidden;
+}
+
 void OutputLayer::read(std::ifstream &file) {
   Weight.read(file);
   Bias.read(file);
index ecee386..5920439 100644 (file)
@@ -282,6 +282,18 @@ Tensor NeuralNetwork::forwarding(Tensor input) {
 }
 
 /**
+ * @brief     forward propagation using layers object which has layer
+ */
+Tensor NeuralNetwork::forwarding(Tensor input, Tensor output) {
+  Tensor X = input;
+  Tensor Y2 = output;
+  for (unsigned int i = 0; i < layers.size(); i++) {
+    X = layers[i]->forwarding(X,Y2);
+  }
+  return X;
+}
+
+/**
  * @brief     back propagation
  *            Call backwarding function of layer in reverse order
  *            No need to call at first Input Layer (No data to be updated)