[neuralnet] add log about training
authorhyeonseok lee <hs89.lee@samsung.com>
Thu, 29 Sep 2022 09:19:26 +0000 (18:19 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 30 Sep 2022 06:44:32 +0000 (15:44 +0900)
 - Added log when start/finish training
 - Added log when get current epoch is called

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/models/neuralnet.cpp
nntrainer/models/neuralnet.h

index 9730b9e..c9ed19e 100644 (file)
@@ -113,6 +113,13 @@ int NeuralNetwork::loadFromConfig(const std::string &config) {
   return ML_ERROR_NONE;
 }
 
+unsigned int NeuralNetwork::getCurrentEpoch() {
+#ifdef DEBUG
+  ml_logd("[NNTrainer] Current epoch: %d", epoch_idx);
+#endif
+  return epoch_idx;
+};
+
 void NeuralNetwork::setProperty(const std::vector<std::string> &values) {
   auto left_props = loadProperties(values, model_props);
   setTrainConfig(left_props);
@@ -745,14 +752,17 @@ int NeuralNetwork::train_run(std::function<bool(void *userdata)> stop_cb) {
     return stat;
   };
 
-  auto train_for_iteration = [this](RunStats &stat, DataBuffer &buffer) {
+  auto train_for_iteration = [this, stop_cb](RunStats &stat,
+                                             DataBuffer &buffer) {
     forwarding(true);
     backwarding(iter++);
 
-    std::cout << "#" << epoch_idx << "/" << getEpochs();
-    ml_logi("# %d / %d", epoch_idx, getEpochs());
-    auto loss = getLoss();
-    buffer.displayProgress(stat.num_iterations, loss);
+    if (!stop_cb(nullptr)) {
+      std::cout << "#" << epoch_idx << "/" << getEpochs();
+      ml_logi("# %d / %d", epoch_idx, getEpochs());
+      auto loss = getLoss();
+      buffer.displayProgress(stat.num_iterations, loss);
+    }
   };
 
   auto update_train_stat = [this](RunStats &stat,
@@ -774,6 +784,12 @@ int NeuralNetwork::train_run(std::function<bool(void *userdata)> stop_cb) {
                 << " - Training Loss: " << stat.loss;
       ml_logi("# %d / %d - Training Loss: %f", epoch_idx, getEpochs(),
               stat.loss);
+      ml_logd("[NNTrainer] Training epoch %d / %d finished successfully.",
+              epoch_idx, getEpochs());
+    } else {
+      ml_logd("[NNTrainer] Training stopped by stop callback function during "
+              "epoch %d.",
+              epoch_idx);
     }
   };
 
@@ -822,6 +838,8 @@ int NeuralNetwork::train_run(std::function<bool(void *userdata)> stop_cb) {
   };
 
   auto epochs = getEpochs();
+  ml_logd("[NNTrainer] Starts training. Current epoch: %d. Total epochs: %d.",
+          epoch_idx + 1, getEpochs());
   for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) {
     if (stop_cb(nullptr)) {
       --epoch_idx;
index d371354..f314a1e 100644 (file)
@@ -268,7 +268,7 @@ public:
    * @brief     get current epoch_idx
    * @retval    current epoch_idx
    */
-  unsigned int getCurrentEpoch() override { return epoch_idx; };
+  unsigned int getCurrentEpoch() override;
 
   /**
    * @brief     Copy Neural Network