[sub-plugin] Add function to stop model training
authorhyunil park <hyunil46.park@samsung.com>
Tue, 8 Aug 2023 02:42:30 +0000 (11:42 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 28 Aug 2023 04:27:15 +0000 (13:27 +0900)
nnstreamer tensor_trainer call this function to stop model training

**Self evaluation:**
1. Build test:   [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyunil park <hyunil46.park@samsung.com>
nnstreamer/tensor_trainer/tensor_trainer_nntrainer.cc
nnstreamer/tensor_trainer/tensor_trainer_nntrainer.hh

index c1286a96a09133f2f81a3c2ce53955552f2d1fb2..9b06345fc086f1ccfb4215065389207ecf29b8aa 100644 (file)
@@ -502,8 +502,27 @@ static int nntrainer_model_start_training(
   return 0;
 }
 
+static int nntrainer_model_stop_training(const GstTensorTrainerFramework *fw,
+                                         const GstTensorTrainerProperties *prop,
+                                         void **private_data) {
+  NNTrainer::InputTensorsInfo *train_data = nullptr, *valid_data = nullptr;
+  NNTrainer::NNTrainerTrain *nntrainer =
+    static_cast<NNTrainer::NNTrainerTrain *>(*private_data);
+  UNUSED(fw);
+  ml_logd("<called>");
+
+  if (!nntrainer)
+    return -1;
+
+  nntrainer->stop_model_training = TRUE;
+
+  ml_logd("<leave>");
+  return 0;
+}
+
 bool stop_cb(void *user_data) {
   bool *ret = reinterpret_cast<bool *>(user_data);
+  ml_logd("<called> %d", *ret);
   return *ret;
 }
 
@@ -527,7 +546,7 @@ void epoch_complete_cb(void *user_data) {
 void NNTrainer::NNTrainerTrain::trainModel() {
   pid_t pid = getpid();
   pid_t tid = syscall(SYS_gettid);
-  bool stop = false;
+  stop_model_training = false;
 
   ml_logd("<called>");
   ml_logd("pid[%d], tid[%d]", pid, tid);
@@ -542,7 +561,8 @@ void NNTrainer::NNTrainerTrain::trainModel() {
   NNTrainer::NNTrainerTrain *nntrainer = GetNNTrainerTrain();
 
   try {
-    model->train({}, stop_cb, &stop, epoch_complete_cb, nntrainer);
+    model->train({}, stop_cb, &stop_model_training, epoch_complete_cb,
+                 nntrainer);
     training_loss = model->getTrainingLoss();
     validation_loss = model->getValidationLoss();
     getRunStats();
@@ -615,7 +635,8 @@ NNTrainer::NNTrainerTrain::NNTrainerTrain(
   validation_loss(0),
   num_push_data(0),
   model_config(_model_config),
-  notifier(nullptr) {
+  notifier(nullptr),
+  stop_model_training(FALSE) {
   ml_logd("<called>");
   getNNStreamerProperties(prop);
   createModel();
@@ -710,6 +731,7 @@ static GstTensorTrainerFramework NNS_Trainer_support_nntrainer = {
   .create = nntrainer_model_construct,
   .destroy = nntrainer_model_destructor,
   .start = nntrainer_model_start_training,
+  .stop = nntrainer_model_stop_training,
   .push_data = nntrainer_model_push_data,
   .getStatus = nntrainer_getStatus,
   .getFrameworkInfo = nntrainer_getFrameworkInfo};
index 74aa21988d3646f3d94c3cdc2f22faf75c7d6b24..d8521b1e4a5eba478c67bf4847430ee4e114fdb7 100644 (file)
@@ -113,6 +113,7 @@ public:
   ml::train::RunStats valid_stats;
 
   GstTensorTrainerEventNotifier *notifier; /**< a handle of event notify */
+  bool stop_model_training;
 
 private:
   std::unique_ptr<ml::train::Model> model;