[API][trainer] Add number of epoch and epoch count
authorhyunil park <hyunil46.park@samsung.com>
Thu, 5 Jan 2023 05:13:00 +0000 (14:13 +0900)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Thu, 12 Jan 2023 10:25:29 +0000 (19:25 +0900)
- Add num_epochs to inform the total amount of data than sub-plugin can receive
- Add epoch count to provide currently epoch information for sub-plugin to tensor_trainer

Signed-off-by: hyunil park <hyunil46.park@samsung.com>
gst/nnstreamer/include/nnstreamer_plugin_api_trainer.h

index 4ff99d9..407c4d0 100644 (file)
@@ -37,6 +37,7 @@ typedef struct _GstTensorTrainerProperties
   int64_t num_labels;    /**< The number of label lists, the label is where framework receive the class to train the model, num_labels indicates how many labels there are. */
   int64_t num_train_samples;    /**< The number of train sample used to train the model. */
   int64_t num_valid_samples;    /**< The number of valid sample used to train the model. */
+  int64_t num_epochs;    /**< The number of repetition of total train and valid sample. subplugin must receive total samples((num_train_samples + num_valid_samples) * num_epochs) */
 
   GCond *train_complete_cond;    /**< Tensor trainer wait when receive EOS before model training is complete, subplugin should send signal when model train is complete. */
 } GstTensorTrainerProperties;
@@ -50,6 +51,7 @@ typedef struct _GstTensorTrainerFrameworkInfo
 {
   const char *name;    /**< Name of the neural network framework, searchable by FRAMEWORK property. */
   gboolean  train_complete;  /**< Check if train is complete */
+  int64_t epoch_cnt;    /**< Number of currently completed epochs */
 } GstTensorTrainerFrameworkInfo;
 
 typedef struct _GstTensorTrainerFramework GstTensorTrainerFramework;