[NeuralNet] Fix testset run
authorJihoon Lee <jhoon.it.lee@samsung.com>
Sat, 28 Aug 2021 09:19:32 +0000 (18:19 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 7 Sep 2021 03:57:48 +0000 (12:57 +0900)
This patch fixes testset running while cleaning up the code to reuse
repeating logics with callback

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/models/neuralnet.cpp
nntrainer/models/neuralnet.h

index ba1e39f..0ed3310 100644 (file)
@@ -632,25 +632,26 @@ int NeuralNetwork::train_run() {
     return ML_ERROR_INVALID_PARAMETER;
   }
 
-  for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) {
-    training.loss = 0.0f;
-
+  /**
+   * @brief run a single epoch with given callback, @a auto is used instead of
+   * std::function for performance measure
+   * @param buffer buffer to run
+   * @param shuffle whether to shuffle or not
+   * @param on_iteration_fetch function that will recieve reference to stat,
+   * buffer which will be called every time data is fetched and set
+   * @param on_epoch_end function that will recieve reference to stat,
+   * buffer which will be called on the epoch end
+   */
+  auto run_epoch = [this, &in, &label, &in_dims, &label_dims](
+                     DataBuffer *buffer, bool shuffle,
+                     auto &&on_iteration_fetch, auto &&on_epoch_end) {
+    /// @todo managing metrics must be handled here as well!! for now it is
+    /// handled in individual callbacks
+    RunStats stat;
     std::future<std::shared_ptr<IterationQueue>> future_iq =
-      train_buffer->startFetchWorker(in_dims, label_dims, true);
-
-    // /// @todo make this working, test buffer is running but doing nothing
-    // if (test_buffer != nullptr && test_buffer->isValid()) {
-    //   status = test_buffer->run();
-    //   if (status != ML_ERROR_NONE) {
-    //     test_buffer->clear();
-    //     return status;
-    //   }
-    // }
-
-    int count = 0;
-
+      buffer->startFetchWorker(in_dims, label_dims, shuffle);
     while (true) {
-      ScopedView<Iteration> iter_view = train_buffer->fetch();
+      ScopedView<Iteration> iter_view = buffer->fetch();
       if (iter_view.isEmpty()) {
         break;
       }
@@ -663,74 +664,73 @@ int NeuralNetwork::train_run() {
       in = iteration.getInputsRef().front();
       label = iteration.getLabelsRef().front();
 
-      forwarding(true);
-      backwarding(iter++);
+      on_iteration_fetch(stat, *buffer);
+    }
+    future_iq.get();
+    on_epoch_end(stat, *buffer);
 
-      std::cout << "#" << epoch_idx << "/" << epochs;
-      float loss = getLoss();
-      train_buffer->displayProgress(count++, loss);
-      training.loss += loss;
+    if (stat.num_iterations == 0) {
+      throw std::runtime_error("No data came while buffer ran");
     }
 
-    future_iq.get();
+    return stat;
+  };
 
-    if (count == 0)
-      throw std::runtime_error("No training data");
+  auto train_for_iteration = [this](RunStats &stat, DataBuffer &buffer) {
+    forwarding(true);
+    backwarding(iter++);
 
-    training.loss /= count;
+    std::cout << "#" << epoch_idx << "/" << epochs;
+    auto loss = getLoss();
+    stat.loss += loss;
+    buffer.displayProgress(stat.num_iterations++, loss);
+  };
+
+  auto train_epoch_end = [this](RunStats &stat, DataBuffer &buffer) {
+    stat.loss /= static_cast<float>(stat.num_iterations);
     if (!save_path.empty()) {
       save(save_path, ml::train::ModelFormat::MODEL_FORMAT_BIN);
     }
 
     std::cout << "#" << epoch_idx << "/" << epochs
-              << " - Training Loss: " << training.loss;
-
-    if (valid_buffer != nullptr) {
-      int right = 0;
-      validation.loss = 0.0f;
-      unsigned int tcases = 0;
-
-      std::future<std::shared_ptr<IterationQueue>> future_iq =
-        valid_buffer->startFetchWorker(in_dims, label_dims, false);
-
-      while (true) {
-        ScopedView<Iteration> iter_view = valid_buffer->fetch();
-        if (iter_view.isEmpty()) {
-          break;
-        }
-        auto &iter = iter_view.get();
-        if (iter.batch() != batch_size) {
-          /// @todo support partial batch
-          continue;
-        }
-        /// @todo multiple input support
-        in = iter.getInputsRef().front();
-        label = iter.getLabelsRef().front();
-
-        forwarding(false);
-        auto model_out = output.argmax();
-        auto label_out = label.argmax();
-        for (unsigned int b = 0; b < batch_size; b++) {
-          if (model_out[b] == label_out[b])
-            right++;
-        }
-        validation.loss += getLoss();
-        tcases++;
-      }
+              << " - Training Loss: " << stat.loss;
+  };
 
-      future_iq.get();
+  auto eval_for_iteration = [this, &output, &label](RunStats &stat,
+                                                    DataBuffer &buffer) {
+    forwarding(false);
+    auto model_out = output.argmax();
+    auto label_out = label.argmax();
+    for (unsigned int b = 0; b < batch_size; b++) {
+      if (model_out[b] == label_out[b])
+        stat.num_correct_predictions++;
+    }
+    stat.num_iterations++;
+    stat.loss += getLoss();
+  };
 
-      if (tcases == 0) {
-        ml_loge("Error : 0 test cases");
-        status = ML_ERROR_INVALID_PARAMETER;
-        return status;
-      }
-      validation.loss /= (float)(tcases);
-      validation.accuracy = right / (float)(tcases * batch_size) * 100.0f;
-      std::cout << " >> [ Accuracy: " << validation.accuracy
-                << "% - Validation Loss : " << validation.loss << " ] ";
+  auto eval_epoch_end = [this](RunStats &stat, DataBuffer &buffer) {
+    stat.loss /= static_cast<float>(stat.num_iterations);
+    stat.accuracy = stat.num_correct_predictions /
+                    static_cast<float>(stat.num_iterations * batch_size) *
+                    100.0f;
+    std::cout << " >> [ Accuracy: " << validation.accuracy
+              << "% - Validation Loss : " << validation.loss << " ]";
+  };
+
+  for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) {
+    training =
+      run_epoch(train_buffer.get(), true, train_for_iteration, train_epoch_end);
+    if (valid_buffer) {
+      validation = run_epoch(valid_buffer.get(), false, eval_for_iteration,
+                             eval_epoch_end);
     }
-    std::cout << std::endl;
+    std::cout << '\n';
+  }
+  if (test_buffer) {
+    std::cout << "Evaluation with test data...\n";
+    testing =
+      run_epoch(test_buffer.get(), false, eval_for_iteration, eval_epoch_end);
   }
 
   return status;
index 8b6cf9d..17d2320 100644 (file)
@@ -62,12 +62,19 @@ using DatasetModeType = ml::train::DatasetModeType;
 /**
  * @brief     Statistics from running or training a model
  */
-typedef struct RunStats_ {
-  float accuracy; /** accuracy of the model */
-  float loss;     /** loss of the model */
-
-  RunStats_() : accuracy(0), loss(0) {}
-} RunStats;
+struct RunStats {
+  float accuracy;     /** accuracy of the model */
+  float loss;         /** loss of the model */
+  int num_iterations; /** number of iterations done on this stat */
+  unsigned int
+    num_correct_predictions; /** number of right sample on this run */
+
+  RunStats() :
+    accuracy(0),
+    loss(0),
+    num_iterations(0),
+    num_correct_predictions(0) {}
+};
 
 /**
  * @class   NeuralNetwork Class