[model] Add epoch complete callback
authorhyunil park <hyunil46.park@samsung.com>
Mon, 15 May 2023 01:58:40 +0000 (10:58 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 23 May 2023 10:09:49 +0000 (19:09 +0900)
- Called the end of an epoch
- Users can do what they need at the end of each epoch. e.g. get RunStats.

Signed-off-by: hyunil park <hyunil46.park@samsung.com>
api/ccapi/include/model.h
nntrainer/models/neuralnet.cpp
nntrainer/models/neuralnet.h

index 7485506..7706739 100644 (file)
@@ -169,9 +169,13 @@ public:
    * @param[in] values hyper parameters
    * @param[in] stop_cb callback function to decide stop training or not
    * ~~~~~
-   * @a user_data user_data to be used in stop_cb
+   * @a stop_user_data user_data to be used in stop_cb
    * @a bool true if stop the training
    * ~~~~~
+   * @param[in] epoch_complete_cb Called the end of an epoch
+   * ~~~~~
+   * @a epoch_user_data user_data to be used in epoch_complete_cb
+   * ~~~~~
    * @retval #ML_ERROR_NONE Successful.
    * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
    * @details   This function accepts vector of properties in the format -
@@ -179,8 +183,11 @@ public:
    */
   virtual int train(const std::vector<std::string> &values = {},
                     std::function<bool(void *)> stop_cb =
-                      [](void *user_data) { return false; },
-                    void *user_data = nullptr) = 0;
+                      [](void *stop_user_data) { return false; },
+                    void *stop_user_data = nullptr,
+                    std::function<void(void *)> epoch_complete_cb =
+                      [](void *epoch_user_data) { return false; },
+                    void *epoch_user_data = nullptr) = 0;
 
   /**
    * @brief     Run Model train with callback function by user
index 09d2ed8..57bb8eb 100644 (file)
@@ -697,7 +697,10 @@ int NeuralNetwork::deallocate() {
 }
 
 int NeuralNetwork::train(const std::vector<std::string> &values,
-                         std::function<bool(void *)> stop_cb, void *user_data) {
+                         std::function<bool(void *)> stop_cb,
+                         void *stop_user_data,
+                         std::function<void(void *)> epoch_complete_cb,
+                         void *epoch_user_data) {
   int status = ML_ERROR_NONE;
 
   if (data_buffers[static_cast<int>(DatasetModeType::MODE_TRAIN)] == nullptr) {
@@ -719,7 +722,8 @@ int NeuralNetwork::train(const std::vector<std::string> &values,
   status = allocate(ExecutionMode::TRAIN);
   NN_RETURN_STATUS();
 
-  status = train_run(stop_cb, user_data);
+  status =
+    train_run(stop_cb, stop_user_data, epoch_complete_cb, epoch_user_data);
   NN_RETURN_STATUS();
 
   /**
@@ -734,8 +738,10 @@ int NeuralNetwork::train(const std::vector<std::string> &values,
 /**
  * @brief     Run NeuralNetwork train with callback function by user
  */
-int NeuralNetwork::train_run(std::function<bool(void *userdata)> stop_cb,
-                             void *user_data) {
+int NeuralNetwork::train_run(
+  std::function<bool(void *userdata)> stop_cb, void *stop_user_data,
+  std::function<void(void *userdata)> epoch_complete_cb,
+  void *epoch_user_data) {
   int status = ML_ERROR_NONE;
 
   if (!std::get<props::ContinueTrain>(model_flex_props)) {
@@ -814,21 +820,21 @@ int NeuralNetwork::train_run(std::function<bool(void *userdata)> stop_cb,
     return stat;
   };
 
-  auto train_for_iteration = [this, stop_cb, user_data](RunStats &stat,
-                                                        DataBuffer &buffer) {
-    forwarding(true, stop_cb, user_data);
-    backwarding(iter++, stop_cb, user_data);
+  auto train_for_iteration =
+    [this, stop_cb, stop_user_data](RunStats &stat, DataBuffer &buffer) {
+      forwarding(true, stop_cb, stop_user_data);
+      backwarding(iter++, stop_cb, stop_user_data);
 
-    // To avoid unconsidered memory leak, we need to clear the cache
-    model_graph.flushCache();
+      // To avoid unconsidered memory leak, we need to clear the cache
+      model_graph.flushCache();
 
-    if (!stop_cb(user_data)) {
-      std::cout << "#" << epoch_idx << "/" << getEpochs();
-      ml_logi("# %d / %d", epoch_idx, getEpochs());
-      auto loss = getLoss();
-      buffer.displayProgress(stat.num_iterations, loss);
-    }
-  };
+      if (!stop_cb(stop_user_data)) {
+        std::cout << "#" << epoch_idx << "/" << getEpochs();
+        ml_logi("# %d / %d", epoch_idx, getEpochs());
+        auto loss = getLoss();
+        buffer.displayProgress(stat.num_iterations, loss);
+      }
+    };
 
   auto update_train_stat = [this](RunStats &stat,
                                   const std::vector<Tensor> &outputs,
@@ -837,8 +843,8 @@ int NeuralNetwork::train_run(std::function<bool(void *userdata)> stop_cb,
     stat.num_iterations++;
   };
 
-  auto train_epoch_end = [this, stop_cb, user_data](RunStats &stat,
-                                                    DataBuffer &buffer) {
+  auto train_epoch_end = [this, stop_cb, stop_user_data](RunStats &stat,
+                                                         DataBuffer &buffer) {
     if (stat.num_iterations != 0) {
       stat.loss /= static_cast<float>(stat.num_iterations);
     } else {
@@ -846,7 +852,7 @@ int NeuralNetwork::train_run(std::function<bool(void *userdata)> stop_cb,
       return;
     }
     auto &save_path = std::get<props::SavePath>(model_flex_props);
-    if (!stop_cb(user_data)) {
+    if (!stop_cb(stop_user_data)) {
       if (!save_path.empty()) {
         save(save_path, ml::train::ModelFormat::MODEL_FORMAT_BIN);
       }
@@ -864,9 +870,9 @@ int NeuralNetwork::train_run(std::function<bool(void *userdata)> stop_cb,
     }
   };
 
-  auto eval_for_iteration = [this, batch_size, stop_cb,
-                             user_data](RunStats &stat, DataBuffer &buffer) {
-    forwarding(false, stop_cb, user_data);
+  auto eval_for_iteration = [this, batch_size, stop_cb, stop_user_data](
+                              RunStats &stat, DataBuffer &buffer) {
+    forwarding(false, stop_cb, stop_user_data);
   };
 
   auto update_eval_stat = [batch_size, &update_train_stat](
@@ -918,7 +924,7 @@ int NeuralNetwork::train_run(std::function<bool(void *userdata)> stop_cb,
   ml_logd("[NNTrainer] Starts training. Current epoch: %d. Total epochs: %d.",
           epoch_idx + 1, getEpochs());
   for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) {
-    if (stop_cb(user_data)) {
+    if (stop_cb(stop_user_data)) {
       --epoch_idx;
       break;
     }
@@ -929,6 +935,7 @@ int NeuralNetwork::train_run(std::function<bool(void *userdata)> stop_cb,
                              update_eval_stat, eval_epoch_end, validation);
     }
     std::cout << '\n';
+    epoch_complete_cb(epoch_user_data);
   }
   PROFILE_MEM_ANNOTATE("TRAIN END");
 
index 43f1e83..0e8820b 100644 (file)
@@ -291,16 +291,22 @@ public:
    * @param[in] values hyper parameters
    * @param[in] stop_cb callback function to decide stop training or not
    * ~~~~~
-   * @a user_data user_data to be used in stop_cb
+   * @a stop_user_data user_data to be used in stop_cb
    * @a bool true if stop the training
    * ~~~~~
+   * @param[in] epoch_complete_cb Called the end of an epoch.
+   * @a epoch_user_data user_data to be used in epoch_complete_cb
+   * ~~~~~
    * @retval #ML_ERROR_NONE Successful.
    * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
    */
   int train(const std::vector<std::string> &values = {},
             std::function<bool(void *)> stop_cb =
-              [](void *user_data) { return false; },
-            void *user_data = nullptr) override;
+              [](void *stop_user_data) { return false; },
+            void *stop_user_data = nullptr,
+            std::function<void(void *)> epoch_complete_cb =
+              [](void *epoch_user_data) { return false; },
+            void *epoch_user_data = nullptr) override;
 
   /**
    * @brief     Run NeuralNetwork inference
@@ -628,12 +634,16 @@ private:
   /**
    * @brief     Run NeuralNetwork train
    * @param[in] stop_cb callback function to decide stop training or not
+   * @param[in] epoch_complete_cb Called the end of an epoch.
    * @retval #ML_ERROR_NONE Successful.
    * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
    */
   int train_run(std::function<bool(void *)> stop_cb =
                   [](void *) { return false; },
-                void *user_data = nullptr);
+                void *user_data = nullptr,
+                std::function<void(void *)> epoch_complete_cb =
+                  [](void *) { return false; },
+                void *data = nullptr);
 
   /**
    * @brief     Swap function for the class