[neuralnet] Add alternatives to getLoss
authorParichay Kapoor <pk.kapoor@samsung.com>
Mon, 5 Oct 2020 05:05:09 +0000 (14:05 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 5 Oct 2020 06:26:35 +0000 (15:26 +0900)
getLoss used to get the current loss of the model which was based
on the previous batch of data which the network ran on.
This does not allow getting training/validation loss.
Added getTrainingLoss and getValidationLoss for this purpose.
And update getLoss description to include this information.

As MNIST application was using getLoss() which returns the
loss of the last ran element, this value was changed with #600
as with #600 last element is a batch of data than just 1 data element.
The application is updated to now compare all three loss with
updated values.
So, this patch fixes that bug in main branch as well.

Resolves #617

**Self evaluation:**
1. Build test: [x]Passed [ ]Failed [ ]Skipped
2. Run test: [x]Passed [ ]Failed [ ]Skipped

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
Applications/MNIST/jni/main.cpp
nntrainer/include/neuralnet.h

index 2484d12..985551f 100644 (file)
@@ -90,6 +90,8 @@ const float tolerance = 0.1;
 std::string data_path;
 
 float training_loss = 0.0;
+float validation_loss = 0.0;
+float last_batch_loss = 0.0;
 
 /**
  * @brief     step function
@@ -241,7 +243,9 @@ int getBatch_val(float **outVec, float **outLabel, bool *last,
 
 #if defined(APP_VALIDATE)
 TEST(MNIST_training, verify_accuracy) {
-  EXPECT_FLOAT_EQ(training_loss, 2.0374029);
+  EXPECT_FLOAT_EQ(training_loss, 2.3255470);
+  EXPECT_FLOAT_EQ(validation_loss, 2.3074534);
+  EXPECT_FLOAT_EQ(last_batch_loss, 2.2916341);
 }
 #endif
 
@@ -307,7 +311,9 @@ int main(int argc, char *argv[]) {
    */
   try {
     NN.train();
-    training_loss = NN.getLoss();
+    training_loss = NN.getTrainingLoss();
+    validation_loss = NN.getValidationLoss();
+    last_batch_loss = NN.getLoss();
   } catch (...) {
     std::cerr << "Error during train" << std::endl;
     return 0;
index 139de19..73182ff 100644 (file)
@@ -95,12 +95,24 @@ public:
   ~NeuralNetwork();
 
   /**
-   * @brief     Get Loss
+   * @brief     Get Loss from the previous ran batch of data
    * @retval    loss value
    */
   float getLoss();
 
   /**
+   * @brief     Get Loss from the previous epoch of training data
+   * @retval    loss value
+   */
+  float getTrainingLoss() { return training.loss; }
+
+  /**
+   * @brief     Get Loss from the previous epoch of validation data
+   * @retval    loss value
+   */
+  float getValidationLoss() { return validation.loss; }
+
+  /**
    * @brief     Get Learning rate
    * @retval    Learning rate
    */