return 0;
}
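+/** @brief Stop the training of the model */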
+static int nntrainer_model_stop_training(const GstTensorTrainerFramework *fw,
+ const GstTensorTrainerProperties *prop,
+ void **private_data) {
+  NNTrainer::NNTrainerTrain *nntrainer =
+    static_cast<NNTrainer::NNTrainerTrain *>(*private_data);
+  UNUSED(fw);
+  UNUSED(prop);
+ ml_logd("<called>");
+
+ if (!nntrainer)
+ return -1;
+
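+  /* stop_cb() returns this flag to nntrainer; training stops once it is true. */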
+  nntrainer->stop_model_training = true;
+
+ ml_logd("<leave>");
+ return 0;
+}
+
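+/* Stop callback passed to model->train(); training stops when it returns true. */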
bool stop_cb(void *user_data) {
bool *ret = reinterpret_cast<bool *>(user_data);
+ ml_logd("<called> %d", *ret);
return *ret;
}
void NNTrainer::NNTrainerTrain::trainModel() {
pid_t pid = getpid();
pid_t tid = syscall(SYS_gettid);
- bool stop = false;
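+  /* Clear any previous stop request before starting a new training run. */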
+ stop_model_training = false;
ml_logd("<called>");
ml_logd("pid[%d], tid[%d]", pid, tid);
NNTrainer::NNTrainerTrain *nntrainer = GetNNTrainerTrain();
try {
- model->train({}, stop_cb, &stop, epoch_complete_cb, nntrainer);
+ model->train({}, stop_cb, &stop_model_training, epoch_complete_cb,
+ nntrainer);
training_loss = model->getTrainingLoss();
validation_loss = model->getValidationLoss();
getRunStats();
validation_loss(0),
num_push_data(0),
model_config(_model_config),
- notifier(nullptr) {
+ notifier(nullptr),
+  stop_model_training(false) {
ml_logd("<called>");
getNNStreamerProperties(prop);
createModel();
.create = nntrainer_model_construct,
.destroy = nntrainer_model_destructor,
.start = nntrainer_model_start_training,
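+  /* Request an in-progress training session to stop */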
+ .stop = nntrainer_model_stop_training,
.push_data = nntrainer_model_push_data,
.getStatus = nntrainer_getStatus,
.getFrameworkInfo = nntrainer_getFrameworkInfo};