[tensor_trainer] Add a function that generates dummy data
authorhyunil park <hyunil46.park@samsung.com>
Wed, 19 Jun 2024 03:40:38 +0000 (12:40 +0900)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Mon, 1 Jul 2024 01:44:17 +0000 (10:44 +0900)
When the pipeline stops and no data is input, the necessary data
for the current epoch is generated for the normal termination of nntrainer.
The data generation callback of nntrainer is no longer called
until all the necessary data for the current epoch are received.

Signed-off-by: hyunil park <hyunil46.park@samsung.com>
gst/nnstreamer/elements/gsttensor_trainer.c
gst/nnstreamer/elements/gsttensor_trainer.h

index 9b8bde3..884f44a 100644 (file)
@@ -307,6 +307,7 @@ gst_tensor_trainer_init (GstTensorTrainer * trainer)
   trainer->is_training_complete = FALSE;
   trainer->is_epoch_complete = FALSE;
   trainer->cur_epoch_data_cnt = 0;
+  trainer->required_sample = 0;
 
   gst_tensors_config_init (&trainer->in_config);
   gst_tensors_config_init (&trainer->out_config);
@@ -344,6 +345,11 @@ gst_tensor_trainer_finalize (GObject * object)
   g_cond_clear (&trainer->epoch_completion_cond);
   g_mutex_clear (&trainer->epoch_completion_lock);
 
+  if (trainer->dummy_data_thread) {
+    g_thread_join (trainer->dummy_data_thread);
+    trainer->dummy_data_thread = NULL;
+  }
+
   if (trainer->fw_created && trainer->fw) {
     trainer->fw->destroy (trainer->fw, &trainer->prop, &trainer->privateData);
   }
@@ -489,6 +495,49 @@ gst_tensor_trainer_check_invalid_param (GstTensorTrainer * trainer)
 }
 
 /**
+ * @brief Dummy data generation thread
+ */
+static gpointer
+gst_tensor_trainer_dummy_data_generation_func (GstTensorTrainer * trainer)
+{
+  guint i;
+  gint ret = -1;
+  gpointer dummy_data[NNS_TENSOR_SIZE_LIMIT] = { NULL };
+  g_return_val_if_fail (trainer != NULL, NULL);
+
+  gst_tensor_trainer_stop_model_training (trainer);
+
+  for (i = 0; i < trainer->output_meta.num_tensors; i++) {
+    dummy_data[i] = g_malloc (trainer->push_tensors[i].size);
+    memset (dummy_data[i], 1, trainer->push_tensors[i].size);
+    trainer->push_tensors[i].data = dummy_data[i];
+  }
+
+  do {
+    GST_INFO_OBJECT (trainer, "cur_epoch_data_cnt=%u",
+        trainer->cur_epoch_data_cnt);
+    GST_INFO_OBJECT (trainer, "num_tensors=%d",
+        trainer->prop.input_meta.num_tensors);
+
+    ret =
+        trainer->fw->push_data (trainer->fw, &trainer->prop,
+        trainer->privateData, trainer->push_tensors);
+
+    if (ret < 0) {
+      GST_ERROR_OBJECT (trainer, "Failed to push dummy data");
+    } else {
+      trainer->cur_epoch_data_cnt++;
+    }
+  } while (trainer->required_sample > trainer->cur_epoch_data_cnt);
+
+  for (i = 0; i < trainer->output_meta.num_tensors; i++)
+    g_free (dummy_data[i]);
+
+  return NULL;
+}
+
+
+/**
  * @brief Change state of tensor_trainsink.
  */
 static GstStateChangeReturn
@@ -530,8 +579,17 @@ gst_tensor_trainer_change_state (GstElement * element,
   switch (transition) {
     case GST_STATE_CHANGE_PLAYING_TO_PAUSED:
       GST_INFO_OBJECT (trainer, "PLAYING_TO_PAUSED");
-      GST_INFO_OBJECT (trainer, "cur_epoch_data_cnt=%u",
-          trainer->cur_epoch_data_cnt);
+      /* need to generate dummy data */
+      if (!trainer->is_training_complete) {
+        if (!g_strcmp0 (trainer->fw_name, "nntrainer")) {
+          GST_INFO_OBJECT (trainer, "cur_epoch_data_cnt=%u",
+              trainer->cur_epoch_data_cnt);
+          trainer->dummy_data_thread =
+              g_thread_new ("dumy_data_generation_func",
+              (GThreadFunc) gst_tensor_trainer_dummy_data_generation_func,
+              trainer);
+        }
+      }
       break;
 
     case GST_STATE_CHANGE_PAUSED_TO_READY:
@@ -581,15 +639,13 @@ gst_tensor_trainer_wait_for_epoch_completion (GstTensorTrainer * trainer)
 static gboolean
 gst_tensor_trainer_epochs_is_complete (GstTensorTrainer * trainer)
 {
-  guint required_sample;
-
   g_return_val_if_fail (trainer != NULL, FALSE);
   g_return_val_if_fail (trainer->fw != NULL, FALSE);
   g_return_val_if_fail (&trainer->prop != NULL, FALSE);
 
-  required_sample =
+  trainer->required_sample =
       trainer->prop.num_training_samples + trainer->prop.num_validation_samples;
-  if (trainer->cur_epoch_data_cnt != required_sample)
+  if (trainer->cur_epoch_data_cnt != trainer->required_sample)
     return FALSE;
 
   gst_tensor_trainer_wait_for_epoch_completion (trainer);
@@ -615,7 +671,6 @@ gst_tensor_trainer_chain (GstPad * sinkpad, GstObject * parent,
   GstMemory *out_mem[NNS_TENSOR_SIZE_LIMIT] = { 0, };
   GstMapInfo out_info[NNS_TENSOR_SIZE_LIMIT];
   GstTensorMemory in_tensors[NNS_TENSOR_SIZE_LIMIT];
-  GstTensorMemory push_tensors[NNS_TENSOR_SIZE_LIMIT];
   GstTensorMemory out_tensors[NNS_TENSOR_SIZE_LIMIT];
   GstTensorMetaInfo in_meta[NNS_TENSOR_SIZE_LIMIT];
   GstTensorMetaInfo out_meta[NNS_TENSOR_SIZE_LIMIT];
@@ -702,16 +757,16 @@ gst_tensor_trainer_chain (GstPad * sinkpad, GstObject * parent,
       goto error;
     }
     /* Copy to data pointer */
-    push_tensors[i] = in_tensors[i];
+    trainer->push_tensors[i] = in_tensors[i];
     GST_INFO ("in_tensors[%u].size= %zd", i, in_tensors[i].size);
     GST_INFO ("in_tensors[%u].data: %p", i, in_tensors[i].data);
-    GST_INFO ("push_tensors[%u].size= %zd", i, push_tensors[i].size);
-    GST_INFO ("push_tensors[%u].data: %p", i, push_tensors[i].data);
+    GST_INFO ("push_tensors[%u].size= %zd", i, trainer->push_tensors[i].size);
+    GST_INFO ("push_tensors[%u].data: %p", i, trainer->push_tensors[i].data);
   }
 
   ret =
       trainer->fw->push_data (trainer->fw, &trainer->prop, trainer->privateData,
-      push_tensors);
+      trainer->push_tensors);
   trainer->cur_epoch_data_cnt++;
 
   /* Free in info */
@@ -1260,8 +1315,6 @@ gst_tensor_trainer_stop_model_training (GstTensorTrainer * trainer)
   g_return_if_fail (trainer != NULL);
   g_return_if_fail (trainer->fw != NULL);
   g_return_if_fail (trainer->fw->stop != NULL);
-  g_return_if_fail (trainer->ready_to_complete_training);
-
   GST_DEBUG_OBJECT (trainer, "Stop model training");
   ret = trainer->fw->stop (trainer->fw, &trainer->prop, &trainer->privateData);
   if (ret != 0) {
index faefcbd..9da6260 100644 (file)
@@ -60,10 +60,13 @@ struct _GstTensorTrainer
   gboolean ready_to_complete_training;
   unsigned int input_ranks[NNS_TENSOR_SIZE_LIMIT];
   unsigned int output_ranks[NNS_TENSOR_SIZE_LIMIT];
+
+  GstTensorMemory push_tensors[NNS_TENSOR_SIZE_LIMIT];
   GstTensorsInfo output_meta;
   GstTensorsConfig out_config;
   GstTensorsConfig in_config;
 
+  guint required_sample;
   guint cur_epoch_data_cnt;      /**< number of total push data in one eposh */
 
   void *privateData; /**< NNFW plugin's private data is stored here */
@@ -75,6 +78,8 @@ struct _GstTensorTrainer
   GCond training_completion_cond;
   GMutex epoch_completion_lock;
   GCond epoch_completion_cond;
+
+  GThread *dummy_data_thread; /**< Dummy data generation thread */
 };
 
 /**