* @param[in] values hyper parameters
* @param[in] stop_cb callback function to decide stop training or not
* ~~~~~
- * @a user_data user_data to be used in stop_cb
+ * @a stop_user_data user_data to be used in stop_cb
* @a bool true if stop the training
* ~~~~~
+ * @param[in] epoch_complete_cb Called the end of an epoch
+ * ~~~~~
+ * @a epoch_user_data user_data to be used in epoch_complete_cb
+ * ~~~~~
* @retval #ML_ERROR_NONE Successful.
* @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
* @details This function accepts vector of properties in the format -
*/
virtual int train(const std::vector<std::string> &values = {},
std::function<bool(void *)> stop_cb =
- [](void *user_data) { return false; },
- void *user_data = nullptr) = 0;
+ [](void *stop_user_data) { return false; },
+ void *stop_user_data = nullptr,
+ std::function<void(void *)> epoch_complete_cb =
+ [](void *epoch_user_data) { return false; },
+ void *epoch_user_data = nullptr) = 0;
/**
* @brief Run Model train with callback function by user
}
int NeuralNetwork::train(const std::vector<std::string> &values,
- std::function<bool(void *)> stop_cb, void *user_data) {
+ std::function<bool(void *)> stop_cb,
+ void *stop_user_data,
+ std::function<void(void *)> epoch_complete_cb,
+ void *epoch_user_data) {
int status = ML_ERROR_NONE;
if (data_buffers[static_cast<int>(DatasetModeType::MODE_TRAIN)] == nullptr) {
status = allocate(ExecutionMode::TRAIN);
NN_RETURN_STATUS();
- status = train_run(stop_cb, user_data);
+ status =
+ train_run(stop_cb, stop_user_data, epoch_complete_cb, epoch_user_data);
NN_RETURN_STATUS();
/**
/**
* @brief Run NeuralNetwork train with callback function by user
*/
-int NeuralNetwork::train_run(std::function<bool(void *userdata)> stop_cb,
- void *user_data) {
+int NeuralNetwork::train_run(
+ std::function<bool(void *userdata)> stop_cb, void *stop_user_data,
+ std::function<void(void *userdata)> epoch_complete_cb,
+ void *epoch_user_data) {
int status = ML_ERROR_NONE;
if (!std::get<props::ContinueTrain>(model_flex_props)) {
return stat;
};
- auto train_for_iteration = [this, stop_cb, user_data](RunStats &stat,
- DataBuffer &buffer) {
- forwarding(true, stop_cb, user_data);
- backwarding(iter++, stop_cb, user_data);
+ auto train_for_iteration =
+ [this, stop_cb, stop_user_data](RunStats &stat, DataBuffer &buffer) {
+ forwarding(true, stop_cb, stop_user_data);
+ backwarding(iter++, stop_cb, stop_user_data);
- // To avoid unconsidered memory leak, we need to clear the cache
- model_graph.flushCache();
+ // To avoid unconsidered memory leak, we need to clear the cache
+ model_graph.flushCache();
- if (!stop_cb(user_data)) {
- std::cout << "#" << epoch_idx << "/" << getEpochs();
- ml_logi("# %d / %d", epoch_idx, getEpochs());
- auto loss = getLoss();
- buffer.displayProgress(stat.num_iterations, loss);
- }
- };
+ if (!stop_cb(stop_user_data)) {
+ std::cout << "#" << epoch_idx << "/" << getEpochs();
+ ml_logi("# %d / %d", epoch_idx, getEpochs());
+ auto loss = getLoss();
+ buffer.displayProgress(stat.num_iterations, loss);
+ }
+ };
auto update_train_stat = [this](RunStats &stat,
const std::vector<Tensor> &outputs,
stat.num_iterations++;
};
- auto train_epoch_end = [this, stop_cb, user_data](RunStats &stat,
- DataBuffer &buffer) {
+ auto train_epoch_end = [this, stop_cb, stop_user_data](RunStats &stat,
+ DataBuffer &buffer) {
if (stat.num_iterations != 0) {
stat.loss /= static_cast<float>(stat.num_iterations);
} else {
return;
}
auto &save_path = std::get<props::SavePath>(model_flex_props);
- if (!stop_cb(user_data)) {
+ if (!stop_cb(stop_user_data)) {
if (!save_path.empty()) {
save(save_path, ml::train::ModelFormat::MODEL_FORMAT_BIN);
}
}
};
- auto eval_for_iteration = [this, batch_size, stop_cb,
- user_data](RunStats &stat, DataBuffer &buffer) {
- forwarding(false, stop_cb, user_data);
+ auto eval_for_iteration = [this, batch_size, stop_cb, stop_user_data](
+ RunStats &stat, DataBuffer &buffer) {
+ forwarding(false, stop_cb, stop_user_data);
};
auto update_eval_stat = [batch_size, &update_train_stat](
ml_logd("[NNTrainer] Starts training. Current epoch: %d. Total epochs: %d.",
epoch_idx + 1, getEpochs());
for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) {
- if (stop_cb(user_data)) {
+ if (stop_cb(stop_user_data)) {
--epoch_idx;
break;
}
update_eval_stat, eval_epoch_end, validation);
}
std::cout << '\n';
+ epoch_complete_cb(epoch_user_data);
}
PROFILE_MEM_ANNOTATE("TRAIN END");
* @param[in] values hyper parameters
* @param[in] stop_cb callback function to decide stop training or not
* ~~~~~
- * @a user_data user_data to be used in stop_cb
+ * @a stop_user_data user_data to be used in stop_cb
* @a bool true if stop the training
* ~~~~~
+ * @param[in] epoch_complete_cb Called the end of an epoch.
+ * @a epoch_user_data user_data to be used in epoch_complete_cb
+ * ~~~~~
* @retval #ML_ERROR_NONE Successful.
* @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
*/
int train(const std::vector<std::string> &values = {},
std::function<bool(void *)> stop_cb =
- [](void *user_data) { return false; },
- void *user_data = nullptr) override;
+ [](void *stop_user_data) { return false; },
+ void *stop_user_data = nullptr,
+ std::function<void(void *)> epoch_complete_cb =
+ [](void *epoch_user_data) { return false; },
+ void *epoch_user_data = nullptr) override;
/**
* @brief Run NeuralNetwork inference
/**
* @brief Run NeuralNetwork train
* @param[in] stop_cb callback function to decide stop training or not
+ * @param[in] epoch_complete_cb Called the end of an epoch.
* @retval #ML_ERROR_NONE Successful.
* @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
*/
int train_run(std::function<bool(void *)> stop_cb =
[](void *) { return false; },
- void *user_data = nullptr);
+ void *user_data = nullptr,
+ std::function<void(void *)> epoch_complete_cb =
+ [](void *) { return false; },
+ void *data = nullptr);
/**
* @brief Swap function for the class