[tensor_trainer] Save the stats information of the model to output tensors
author hyunil park <hyunil46.park@samsung.com>
Thu, 25 May 2023 02:15:19 +0000 (11:15 +0900)
committer wooksong <wook16.song@samsung.com>
Wed, 31 May 2023 04:09:16 +0000 (13:09 +0900)
Whenever an epoch completes, the statistics of the model being trained
in the sub-plugin are stored in the output tensor.

- Apply getStatus to get status of subplugin
- Add new internal API(nnstreamer_trainer_notify_event()) to get event from subplugin
- Add function to write stats to output tensors

Signed-off-by: hyunil park <hyunil46.park@samsung.com>
gst/nnstreamer/elements/gsttensor_trainer.c
gst/nnstreamer/elements/gsttensor_trainer.h

index ebd846e..3fb85d3 100644 (file)
  * ]|
  *
  * Total number of data to be received is 1000((num-training-samples + num-validation-samples) * epochs)
+ * 
+ * output tensors : dimensions=1:1:4, types=float64.
+ * values are training loss, training accuracy, validation loss and validation accuracy.
+ * -INFINITY value is stored if the value fetched from the sub-plugin is not greater than 0.
  */
 
 #ifdef HAVE_CONFIG_H
@@ -28,6 +32,7 @@
 #include <nnstreamer_util.h>
 #include "gsttensor_trainer.h"
 #include <unistd.h>
+#include <math.h>
 
 /**
  * @brief Default caps string for both sink and source pad.
@@ -56,6 +61,19 @@ GST_DEBUG_CATEGORY_STATIC (gst_tensor_trainer_debug);
 G_DEFINE_TYPE (GstTensorTrainer, gst_tensor_trainer, GST_TYPE_ELEMENT);
 
 /**
+ * @brief Statistics of the model being trained
+ * Each enumerator is the index at which that statistic is stored in the
+ * model_stats array, which the chain function copies into the output tensor.
+ */
+enum
+{
+  TRAINING_LOSS, /**< index of the training loss */
+  TRAINING_ACCURACY, /**< index of the training accuracy */
+  VALIDATION_LOSS, /**< index of the validation loss */
+  VALIDATION_ACCURACY /**< index of the validation accuracy */
+};
+#define MODEL_STATS_SIZE 4 /* number of enumerators above, keep in sync */
+
+/**
  * @brief Default framework property value
  */
 #define DEFAULT_PROP_INPUT_LIST 1
@@ -63,14 +81,13 @@ G_DEFINE_TYPE (GstTensorTrainer, gst_tensor_trainer, GST_TYPE_ELEMENT);
 #define DEFAULT_PROP_TRAIN_SAMPLES 0
 #define DEFAULT_PROP_VALID_SAMPLES 0
 #define DEFAULT_PROP_EPOCHS 1
-
 /**
  * @brief Default string property value
  */
 #define DEFAULT_STR_PROP_VALUE ""
 
 /**
- * @brief Default string property value
+ * @brief tensor_trainer properties
  */
 enum
 {
@@ -116,6 +133,8 @@ static gboolean gst_tensor_trainer_create_framework (GstTensorTrainer *
 static gsize gst_tensor_trainer_get_tensor_size (GstTensorTrainer * trainer,
     guint index, gboolean is_input);
 static gboolean gst_tensor_trainer_create_model (GstTensorTrainer * trainer);
+static void gst_tensor_trainer_create_event_notifier (GstTensorTrainer *
+    trainer);
 static void gst_tensor_trainer_train_model (GstTensorTrainer * trainer);
 static void gst_tensor_trainer_output_dimension (GstTensorTrainer * trainer);
 static void gst_tensor_trainer_output_type (GstTensorTrainer * trainer);
@@ -262,14 +281,17 @@ gst_tensor_trainer_init (GstTensorTrainer * trainer)
   trainer->input_configured = FALSE;
   trainer->output_configured = FALSE;
   trainer->inputtype_configured = FALSE;
+  trainer->is_training_complete = FALSE;
+  trainer->is_epoch_complete = FALSE;
   trainer->total_push_data_cnt = 0;
 
   gst_tensors_config_init (&trainer->in_config);
   gst_tensors_config_init (&trainer->out_config);
 
-  g_cond_init (&trainer->training_complete_cond);
-  g_mutex_init (&trainer->trainer_lock);
-  trainer->prop.training_complete_cond = &trainer->training_complete_cond;
+  g_cond_init (&trainer->training_completion_cond);
+  g_mutex_init (&trainer->training_completion_lock);
+  g_cond_init (&trainer->epoch_completion_cond);
+  g_mutex_init (&trainer->epoch_completion_lock);
 
   gst_tensor_trainer_output_dimension (trainer);
   gst_tensor_trainer_output_type (trainer);
@@ -294,8 +316,10 @@ gst_tensor_trainer_finalize (GObject * object)
   gst_tensors_config_free (&trainer->in_config);
   gst_tensors_config_free (&trainer->out_config);
 
-  g_cond_clear (&trainer->training_complete_cond);
-  g_mutex_clear (&trainer->trainer_lock);
+  g_cond_clear (&trainer->training_completion_cond);
+  g_mutex_clear (&trainer->training_completion_lock);
+  g_cond_clear (&trainer->epoch_completion_cond);
+  g_mutex_clear (&trainer->epoch_completion_lock);
 
   if (trainer->fw_created && trainer->fw) {
     trainer->fw->destroy (trainer->fw, &trainer->prop, &trainer->privateData);
@@ -435,6 +459,7 @@ gst_tensor_trainer_change_state (GstElement * element,
       if (!gst_tensor_trainer_create_model (trainer))
         goto state_change_failed;
 
+      gst_tensor_trainer_create_event_notifier (trainer);
       gst_tensor_trainer_train_model (trainer);
       break;
 
@@ -472,6 +497,52 @@ state_change_failed:
 }
 
 /**
+ * @brief Wait for epoch eompletion
+ */
+static void
+gst_tensor_trainer_wait_for_epoch_completion (GstTensorTrainer * trainer)
+{
+  g_return_if_fail (trainer != NULL);
+
+  g_mutex_lock (&trainer->epoch_completion_lock);
+  if (trainer->is_epoch_complete) {
+    /* It's already completed */
+    trainer->is_epoch_complete = FALSE;
+    g_mutex_unlock (&trainer->epoch_completion_lock);
+    return;
+  }
+
+  GST_INFO_OBJECT (trainer, "wait for epoch_completion_cond signal");
+  g_cond_wait (&trainer->epoch_completion_cond,
+      &trainer->epoch_completion_lock);
+  trainer->is_epoch_complete = FALSE;
+  g_mutex_unlock (&trainer->epoch_completion_lock);
+}
+
+/**
+ * @brief Check if the current epoch is complete.
+ * When the number of pushed data is a multiple of
+ * (num_training_samples + num_validation_samples), tensor_trainer waits for
+ * the sub-plugin to report the epoch completion before the results are
+ * fetched from it.
+ * @param[in] trainer tensor_trainer element, must not be NULL
+ * @return TRUE if an epoch boundary was reached and its completion has been
+ *         waited for, FALSE otherwise (including invalid arguments and the
+ *         case where no samples are configured).
+ */
+static gboolean
+gst_tensor_trainer_epochs_is_complete (GstTensorTrainer * trainer)
+{
+  guint required_sample;
+
+  g_return_val_if_fail (trainer != NULL, FALSE);
+  g_return_val_if_fail (trainer->fw != NULL, FALSE);
+
+  required_sample =
+      trainer->prop.num_training_samples + trainer->prop.num_validation_samples;
+
+  /*
+   * Both sample counts default to 0. Using 0 as the divisor of the modulo
+   * below is undefined behavior, so treat it as "no epoch boundary".
+   */
+  if (required_sample == 0) {
+    GST_ERROR_OBJECT (trainer,
+        "num-training-samples + num-validation-samples is 0, cannot detect epoch completion");
+    return FALSE;
+  }
+
+  /* Not on an epoch boundary yet */
+  if (trainer->total_push_data_cnt % required_sample != 0)
+    return FALSE;
+
+  gst_tensor_trainer_wait_for_epoch_completion (trainer);
+
+  return TRUE;
+}
+
+/**
  * @brief Chain function, this function does the actual processing.
  */
 static GstFlowReturn
@@ -494,13 +565,17 @@ gst_tensor_trainer_chain (GstPad * sinkpad, GstObject * parent,
   GstTensorMetaInfo in_meta[NNS_TENSOR_SIZE_LIMIT];
   GstTensorMetaInfo out_meta[NNS_TENSOR_SIZE_LIMIT];
 
+  double model_stats[MODEL_STATS_SIZE] =
+      { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
+  void *ptr;
+
   trainer = GST_TENSOR_TRAINER (parent);
 
   mem_blocks = gst_buffer_n_memory (inbuf);
   for (i = 0; i < mem_blocks; i++) {
     in_mem[i] = gst_buffer_peek_memory (inbuf, i);
     if (!gst_memory_map (in_mem[i], &in_info[i], GST_MAP_READ)) {
-      GST_ERROR_OBJECT (trainer, "Could not map in_mem[%d] GstMemory", i);
+      GST_ERROR_OBJECT (trainer, "Could not map in_mem[%u] GstMemory", i);
       goto error;
     }
     in_flexible = gst_tensor_pad_caps_is_flexible (sinkpad);
@@ -521,11 +596,11 @@ gst_tensor_trainer_chain (GstPad * sinkpad, GstObject * parent,
 
   /* Prepare tensor to push */
   /* Check number of input tensors */
-  GST_DEBUG_OBJECT (trainer, "num_tensors: %d",
+  GST_DEBUG_OBJECT (trainer, "num_tensors: %u",
       trainer->prop.input_meta.num_tensors);
   if (mem_blocks != trainer->prop.input_meta.num_tensors) {
-    GST_ERROR_OBJECT (trainer, "Invalid memory blocks(%d),"
-        "number of input tensors may be (%d)", mem_blocks,
+    GST_ERROR_OBJECT (trainer, "Invalid memory blocks(%u),"
+        "number of input tensors may be (%u)", mem_blocks,
         trainer->prop.input_meta.num_tensors);
     goto error;
   }
@@ -541,10 +616,10 @@ gst_tensor_trainer_chain (GstPad * sinkpad, GstObject * parent,
     }
     /* Copy to data pointer */
     push_tensors[i] = in_tensors[i];
-    GST_INFO ("in_tensors[%d].size= %zd", i, in_tensors[i].size);
-    GST_INFO ("in_tensors[%d].data: %p", i, in_tensors[i].data);
-    GST_INFO ("push_tensors[%d].size= %zd", i, push_tensors[i].size);
-    GST_INFO ("push_tensors[%d].data: %p", i, push_tensors[i].data);
+    GST_INFO ("in_tensors[%u].size= %zd", i, in_tensors[i].size);
+    GST_INFO ("in_tensors[%u].data: %p", i, in_tensors[i].data);
+    GST_INFO ("push_tensors[%u].size= %zd", i, push_tensors[i].size);
+    GST_INFO ("push_tensors[%u].data: %p", i, push_tensors[i].data);
   }
 
   ret =
@@ -566,9 +641,7 @@ gst_tensor_trainer_chain (GstPad * sinkpad, GstObject * parent,
       Scheduling with subplugin does not work.
    */
   if (trainer->total_push_data_cnt == 1
-      || trainer->total_push_data_cnt ==
-      trainer->prop.num_training_samples +
-      trainer->prop.num_validation_samples) {
+      || gst_tensor_trainer_epochs_is_complete (trainer)) {
 
     /* Prepare output tensor */
     for (i = 0; i < trainer->output_meta.num_tensors; i++) {
@@ -596,7 +669,7 @@ gst_tensor_trainer_chain (GstPad * sinkpad, GstObject * parent,
       }
 
       if (!gst_memory_map (out_mem[i], &out_info[i], GST_MAP_WRITE)) {
-        GST_ERROR_OBJECT (trainer, "Could not map in_mem[%d] GstMemory", i);
+        GST_ERROR_OBJECT (trainer, "Could not map in_mem[%u] GstMemory", i);
         goto error;
       }
 
@@ -610,13 +683,35 @@ gst_tensor_trainer_chain (GstPad * sinkpad, GstObject * parent,
           goto error;
         }
       }
-#if 0
-      /** @todo Need to updatd out_tensors */
-      /* get loss, accuracy, val_loss, val_accuracy */
-      double data[4] = { 0, 0, 0, 0 };
+
+      ret =
+          trainer->fw->getStatus (trainer->fw, &trainer->prop,
+          trainer->privateData);
+      if (ret < 0) {
+        GST_ERROR_OBJECT (trainer, "Failed to Get status from sub-plugin.(%s).",
+            trainer->fw_name);
+        return GST_FLOW_ERROR;
+      }
+      /* If the value is invalid, it is already set by -INFINITY. */
+      if (trainer->prop.training_loss > 0)
+        model_stats[TRAINING_LOSS] = trainer->prop.training_loss;
+      if (trainer->prop.training_accuracy > 0)
+        model_stats[TRAINING_ACCURACY] = trainer->prop.training_accuracy;
+      if (trainer->prop.validation_loss > 0)
+        model_stats[VALIDATION_LOSS] = trainer->prop.validation_loss;
+      if (trainer->prop.validation_accuracy > 0)
+        model_stats[VALIDATION_ACCURACY] = trainer->prop.validation_accuracy;
+
+      GST_DEBUG_OBJECT (trainer,
+          "#%u/%u epochs [training_loss: %f, training_accuracy: %f, validation_loss: %f, validation_accuracy: %f]",
+          trainer->prop.epoch_count, trainer->prop.num_epochs,
+          model_stats[TRAINING_LOSS], model_stats[TRAINING_ACCURACY],
+          model_stats[VALIDATION_LOSS], model_stats[VALIDATION_ACCURACY]);
+
+      /* update out_tensors */
+      /* write training loss, training accuracy, validation loss, validation accuracy */
       ptr = out_info[i].data;
-      memcpy (ptr, data, sizeof (data));
-#endif
+      memcpy (ptr, model_stats, sizeof (model_stats));
     }
 
     /* Free out info */
@@ -695,6 +790,35 @@ gst_tensor_trainer_query_caps (GstTensorTrainer * trainer,
 }
 
 /**
+ * @brief Wait for training completion
+ */
+static gboolean
+gst_tensor_trainer_wait_for_training_completion (GstTensorTrainer * trainer)
+{
+  g_return_val_if_fail (trainer != NULL, FALSE);
+
+  g_mutex_lock (&trainer->training_completion_lock);
+  if (trainer->is_training_complete) {
+    /* It's already completed */
+    trainer->is_training_complete = FALSE;
+    g_mutex_unlock (&trainer->training_completion_lock);
+    return TRUE;
+  }
+
+  GST_INFO_OBJECT (trainer,
+      "got GST_EVENT_EOS event but training is not completed, state is %d",
+      GST_STATE (trainer));
+
+  GST_INFO_OBJECT (trainer, "wait for training_completion_cond signal");
+  g_cond_wait (&trainer->training_completion_cond,
+      &trainer->training_completion_lock);
+  trainer->is_training_complete = FALSE;
+  g_mutex_unlock (&trainer->training_completion_lock);
+
+  return TRUE;
+}
+
+/**
  * @brief Event handler for sink pad of tensor_trainer
  */
 static gboolean
@@ -702,7 +826,6 @@ gst_tensor_trainer_sink_event (GstPad * sinkpad, GstObject * parent,
     GstEvent * event)
 {
   GstTensorTrainer *trainer;
-  GstTensorTrainerFrameworkInfo info;
   trainer = GST_TENSOR_TRAINER (parent);
 
   GST_DEBUG_OBJECT (trainer, "Received %s event: %" GST_PTR_FORMAT,
@@ -710,17 +833,9 @@ gst_tensor_trainer_sink_event (GstPad * sinkpad, GstObject * parent,
 
   switch (GST_EVENT_TYPE (event)) {
     case GST_EVENT_EOS:
-      trainer->fw->getFrameworkInfo (trainer->fw, NULL, trainer->privateData,
-          &info);
-      if (!info.is_training_complete) {
-        GST_INFO_OBJECT (trainer,
-            "got GST_EVENT_EOS event but training is not completed, state is %d",
-            GST_STATE (trainer));
-        g_mutex_lock (&trainer->trainer_lock);
-        GST_INFO_OBJECT (trainer, "wait for training_complete_cond signal");
-        g_cond_wait (&trainer->training_complete_cond, &trainer->trainer_lock);
-        g_mutex_unlock (&trainer->trainer_lock);
-      }
+      if (gst_tensor_trainer_wait_for_training_completion (trainer))
+        GST_DEBUG_OBJECT (trainer, "training is completed in sub-plugin[%s]",
+            trainer->fw_name);
       break;
     case GST_EVENT_FLUSH_START:
       GST_INFO_OBJECT (trainer, "get GST_EVENT_FLUSH_START event");
@@ -1013,6 +1128,18 @@ gst_tensor_trainer_create_model (GstTensorTrainer * trainer)
 }
 
 /**
+ * @brief Create a event notifier
+ */
+static void
+gst_tensor_trainer_create_event_notifier (GstTensorTrainer * trainer)
+{
+  g_return_if_fail (trainer != NULL);
+  g_return_if_fail (trainer->fw != NULL);
+
+  trainer->notifier.notifier = (void *) trainer;
+}
+
+/**
  * @brief Train model
  */
 static void
@@ -1024,7 +1151,9 @@ gst_tensor_trainer_train_model (GstTensorTrainer * trainer)
   g_return_if_fail (trainer->fw->start != NULL);
 
   GST_DEBUG_OBJECT (trainer, "Start training model");
-  ret = trainer->fw->start (trainer->fw, &trainer->prop, trainer->privateData);
+  ret =
+      trainer->fw->start (trainer->fw, &trainer->prop, &trainer->notifier,
+      trainer->privateData);
   if (ret != 0) {
     GST_ERROR_OBJECT (trainer, "model training is failed");
   }
@@ -1117,3 +1246,42 @@ nnstreamer_trainer_exit (GstTensorTrainerFramework * ttsp)
 
   return unregister_subplugin (NNS_SUBPLUGIN_TRAINER, name);
 }
+
+/**
+ * @brief Trainer's sub-plugin may call this to send event.
+ * TRAINER_EVENT_EPOCH_COMPLETION sets is_epoch_complete and signals
+ * epoch_completion_cond while holding epoch_completion_lock;
+ * TRAINER_EVENT_TRAINING_COMPLETION does the same with is_training_complete,
+ * training_completion_cond and training_completion_lock.
+ * Any other event type is ignored by the switch.
+ * @param[in] notifier event notifier, sub-plugin must send events with this.
+ *            Its notifier field is expected to hold the GstTensorTrainer.
+ * @param[in] type event type
+ * @param[in] data not used by this function
+ * @note NOTE(review): the type guard joins its two comparisons with '||'.
+ *       If TRAINER_EVENT_UNKNOWN is the largest enumerator the condition is
+ *       always true and the guard never rejects anything. Confirm the
+ *       intended range check against the GstTensorTrainerEventType definition.
+ */
+void
+nnstreamer_trainer_notify_event (GstTensorTrainerEventNotifier * notifier,
+    GstTensorTrainerEventType type, void *data)
+{
+  GstTensorTrainer *trainer;
+  g_return_if_fail (notifier != NULL);
+  g_return_if_fail (type < TRAINER_EVENT_UNKNOWN || type > 0);
+  UNUSED (data);
+
+  trainer = (GstTensorTrainer *) notifier->notifier;
+  g_return_if_fail (GST_IS_TENSOR_TRAINER (trainer));
+
+  GST_DEBUG ("Received GstTensorTrainerEvent(%d)", type);
+
+  switch (type) {
+    case TRAINER_EVENT_EPOCH_COMPLETION:
+      g_mutex_lock (&trainer->epoch_completion_lock);
+      trainer->is_epoch_complete = TRUE;
+      GST_DEBUG ("send epoch_completion_cond signal");
+      g_cond_signal (&trainer->epoch_completion_cond);
+      g_mutex_unlock (&trainer->epoch_completion_lock);
+      break;
+    case TRAINER_EVENT_TRAINING_COMPLETION:
+      g_mutex_lock (&trainer->training_completion_lock);
+      trainer->is_training_complete = TRUE;
+      GST_DEBUG ("send training_completion_cond signal");
+      g_cond_signal (&trainer->training_completion_cond);
+      g_mutex_unlock (&trainer->training_completion_lock);
+      break;
+    default:
+      break;
+  }
+}
index d558c9e..9c43233 100644 (file)
@@ -56,21 +56,26 @@ struct _GstTensorTrainer
   gboolean input_configured;
   gboolean output_configured;
   gboolean inputtype_configured;
+  gboolean fw_created;
+  gboolean is_training_complete;
+  gboolean is_epoch_complete;
   unsigned int input_ranks[NNS_TENSOR_SIZE_LIMIT];
   unsigned int output_ranks[NNS_TENSOR_SIZE_LIMIT];
   GstTensorsInfo output_meta;
   GstTensorsConfig out_config;
   GstTensorsConfig in_config;
 
-  gint64 total_push_data_cnt;      /**< number of total push data */
-  gboolean fw_created;
+  guint total_push_data_cnt;      /**< number of total push data in one epoch */
 
   void *privateData; /**< NNFW plugin's private data is stored here */
-  const GstTensorTrainerFramework *fw;  /* for test, need to make */
+  const GstTensorTrainerFramework *fw; /**< Subplugin definition */
   GstTensorTrainerProperties prop; /**< NNFW plugin's properties */
+  GstTensorTrainerEventNotifier notifier; /**< Event notifier */
 
-  GMutex trainer_lock;
-  GCond training_complete_cond;
+  GMutex training_completion_lock;
+  GCond training_completion_cond;
+  GMutex epoch_completion_lock;
+  GCond epoch_completion_cond;
 };
 
 /**