return ML_ERROR_NONE;
}
+unsigned int NeuralNetwork::getCurrentEpoch() {
+#ifdef DEBUG
+ ml_logd("[NNTrainer] Current epoch: %d", epoch_idx);
+#endif
+ return epoch_idx;
+};
+
void NeuralNetwork::setProperty(const std::vector<std::string> &values) {
auto left_props = loadProperties(values, model_props);
setTrainConfig(left_props);
return stat;
};
- auto train_for_iteration = [this](RunStats &stat, DataBuffer &buffer) {
+ auto train_for_iteration = [this, stop_cb](RunStats &stat,
+ DataBuffer &buffer) {
forwarding(true);
backwarding(iter++);
- std::cout << "#" << epoch_idx << "/" << getEpochs();
- ml_logi("# %d / %d", epoch_idx, getEpochs());
- auto loss = getLoss();
- buffer.displayProgress(stat.num_iterations, loss);
+ if (!stop_cb(nullptr)) {
+ std::cout << "#" << epoch_idx << "/" << getEpochs();
+ ml_logi("# %d / %d", epoch_idx, getEpochs());
+ auto loss = getLoss();
+ buffer.displayProgress(stat.num_iterations, loss);
+ }
};
auto update_train_stat = [this](RunStats &stat,
<< " - Training Loss: " << stat.loss;
ml_logi("# %d / %d - Training Loss: %f", epoch_idx, getEpochs(),
stat.loss);
+ ml_logd("[NNTrainer] Training epoch %d / %d finished successfully.",
+ epoch_idx, getEpochs());
+ } else {
+ ml_logd("[NNTrainer] Training stopped by stop callback function during "
+ "epoch %d.",
+ epoch_idx);
}
};
};
auto epochs = getEpochs();
+ ml_logd("[NNTrainer] Starts training. Current epoch: %d. Total epochs: %d.",
+ epoch_idx + 1, getEpochs());
for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) {
if (stop_cb(nullptr)) {
--epoch_idx;