From be69f92c14c93928c1fc181bfb2a2b735a1e2263 Mon Sep 17 00:00:00 2001 From: hyeonseok lee Date: Thu, 29 Sep 2022 18:19:26 +0900 Subject: [PATCH] [neuralnet] add log about training - Added log when start/finish training - Added log when get current epoch is called Signed-off-by: hyeonseok lee --- nntrainer/models/neuralnet.cpp | 28 +++++++++++++++++++++++----- nntrainer/models/neuralnet.h | 2 +- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/nntrainer/models/neuralnet.cpp b/nntrainer/models/neuralnet.cpp index 9730b9e..c9ed19e 100644 --- a/nntrainer/models/neuralnet.cpp +++ b/nntrainer/models/neuralnet.cpp @@ -113,6 +113,13 @@ int NeuralNetwork::loadFromConfig(const std::string &config) { return ML_ERROR_NONE; } +unsigned int NeuralNetwork::getCurrentEpoch() { +#ifdef DEBUG + ml_logd("[NNTrainer] Current epoch: %d", epoch_idx); +#endif + return epoch_idx; +}; + void NeuralNetwork::setProperty(const std::vector &values) { auto left_props = loadProperties(values, model_props); setTrainConfig(left_props); @@ -745,14 +752,17 @@ int NeuralNetwork::train_run(std::function stop_cb) { return stat; }; - auto train_for_iteration = [this](RunStats &stat, DataBuffer &buffer) { + auto train_for_iteration = [this, stop_cb](RunStats &stat, + DataBuffer &buffer) { forwarding(true); backwarding(iter++); - std::cout << "#" << epoch_idx << "/" << getEpochs(); - ml_logi("# %d / %d", epoch_idx, getEpochs()); - auto loss = getLoss(); - buffer.displayProgress(stat.num_iterations, loss); + if (!stop_cb(nullptr)) { + std::cout << "#" << epoch_idx << "/" << getEpochs(); + ml_logi("# %d / %d", epoch_idx, getEpochs()); + auto loss = getLoss(); + buffer.displayProgress(stat.num_iterations, loss); + } }; auto update_train_stat = [this](RunStats &stat, @@ -774,6 +784,12 @@ int NeuralNetwork::train_run(std::function stop_cb) { << " - Training Loss: " << stat.loss; ml_logi("# %d / %d - Training Loss: %f", epoch_idx, getEpochs(), stat.loss); + ml_logd("[NNTrainer] Training epoch %d / %d finished successfully.", + epoch_idx, getEpochs()); + } else { + ml_logd("[NNTrainer] Training stopped by stop callback function during " + "epoch %d.", + epoch_idx); } }; @@ -822,6 +838,8 @@ int NeuralNetwork::train_run(std::function stop_cb) { }; auto epochs = getEpochs(); + ml_logd("[NNTrainer] Starts training. Current epoch: %d. Total epochs: %d.", + epoch_idx + 1, getEpochs()); for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) { if (stop_cb(nullptr)) { --epoch_idx; diff --git a/nntrainer/models/neuralnet.h b/nntrainer/models/neuralnet.h index d371354..f314a1e 100644 --- a/nntrainer/models/neuralnet.h +++ b/nntrainer/models/neuralnet.h @@ -268,7 +268,7 @@ public: * @brief get current epoch_idx * @retval current epoch_idx */ - unsigned int getCurrentEpoch() override { return epoch_idx; }; + unsigned int getCurrentEpoch() override; /** * @brief Copy Neural Network -- 2.7.4