* ]|
*
* Total number of data to be received is 1000((num-training-samples + num-validation-samples) * epochs)
+ *
+ * output tensors : dimensions=1:1:4, types=float64.
+ * values are training loss, training accuracy, validation loss and validation accuracy.
+ * -INFINITY value is stored if the value fetched from the sub-plugin is not greater than 0.
*/
#ifdef HAVE_CONFIG_H
#include <nnstreamer_util.h>
#include "gsttensor_trainer.h"
#include <unistd.h>
+#include <math.h>
/**
* @brief Default caps string for both sink and source pad.
G_DEFINE_TYPE (GstTensorTrainer, gst_tensor_trainer, GST_TYPE_ELEMENT);
/**
+ * @brief Statistical from the model being trained
+ * An enum value indicates the value stored at the index of the output tensor.
+ */
+enum
+{
+ TRAINING_LOSS,
+ TRAINING_ACCURACY,
+ VALIDATION_LOSS,
+ VALIDATION_ACCURACY
+};
+#define MODEL_STATS_SIZE 4
+
+/**
* @brief Default framework property value
*/
#define DEFAULT_PROP_INPUT_LIST 1
#define DEFAULT_PROP_TRAIN_SAMPLES 0
#define DEFAULT_PROP_VALID_SAMPLES 0
#define DEFAULT_PROP_EPOCHS 1
-
/**
* @brief Default string property value
*/
#define DEFAULT_STR_PROP_VALUE ""
/**
- * @brief Default string property value
+ * @brief tensor_trainer properties
*/
enum
{
static gsize gst_tensor_trainer_get_tensor_size (GstTensorTrainer * trainer,
guint index, gboolean is_input);
static gboolean gst_tensor_trainer_create_model (GstTensorTrainer * trainer);
+static void gst_tensor_trainer_create_event_notifier (GstTensorTrainer *
+ trainer);
static void gst_tensor_trainer_train_model (GstTensorTrainer * trainer);
static void gst_tensor_trainer_output_dimension (GstTensorTrainer * trainer);
static void gst_tensor_trainer_output_type (GstTensorTrainer * trainer);
trainer->input_configured = FALSE;
trainer->output_configured = FALSE;
trainer->inputtype_configured = FALSE;
+ trainer->is_training_complete = FALSE;
+ trainer->is_epoch_complete = FALSE;
trainer->total_push_data_cnt = 0;
gst_tensors_config_init (&trainer->in_config);
gst_tensors_config_init (&trainer->out_config);
- g_cond_init (&trainer->training_complete_cond);
- g_mutex_init (&trainer->trainer_lock);
- trainer->prop.training_complete_cond = &trainer->training_complete_cond;
+ g_cond_init (&trainer->training_completion_cond);
+ g_mutex_init (&trainer->training_completion_lock);
+ g_cond_init (&trainer->epoch_completion_cond);
+ g_mutex_init (&trainer->epoch_completion_lock);
gst_tensor_trainer_output_dimension (trainer);
gst_tensor_trainer_output_type (trainer);
gst_tensors_config_free (&trainer->in_config);
gst_tensors_config_free (&trainer->out_config);
- g_cond_clear (&trainer->training_complete_cond);
- g_mutex_clear (&trainer->trainer_lock);
+ g_cond_clear (&trainer->training_completion_cond);
+ g_mutex_clear (&trainer->training_completion_lock);
+ g_cond_clear (&trainer->epoch_completion_cond);
+ g_mutex_clear (&trainer->epoch_completion_lock);
if (trainer->fw_created && trainer->fw) {
trainer->fw->destroy (trainer->fw, &trainer->prop, &trainer->privateData);
if (!gst_tensor_trainer_create_model (trainer))
goto state_change_failed;
+ gst_tensor_trainer_create_event_notifier (trainer);
gst_tensor_trainer_train_model (trainer);
break;
}
/**
+ * @brief Wait for epoch eompletion
+ */
+static void
+gst_tensor_trainer_wait_for_epoch_completion (GstTensorTrainer * trainer)
+{
+ g_return_if_fail (trainer != NULL);
+
+ g_mutex_lock (&trainer->epoch_completion_lock);
+ if (trainer->is_epoch_complete) {
+ /* It's already completed */
+ trainer->is_epoch_complete = FALSE;
+ g_mutex_unlock (&trainer->epoch_completion_lock);
+ return;
+ }
+
+ GST_INFO_OBJECT (trainer, "wait for epoch_completion_cond signal");
+ g_cond_wait (&trainer->epoch_completion_cond,
+ &trainer->epoch_completion_lock);
+ trainer->is_epoch_complete = FALSE;
+ g_mutex_unlock (&trainer->epoch_completion_lock);
+}
+
+/**
+ * @brief Check if current epochs is complete,
+ * tensor_trainer wait for one of epochs to complete before getting the results from the subplugin
+ */
+static gboolean
+gst_tensor_trainer_epochs_is_complete (GstTensorTrainer * trainer)
+{
+ int required_sample;
+
+ g_return_val_if_fail (trainer != NULL, FALSE);
+ g_return_val_if_fail (trainer->fw != NULL, FALSE);
+ g_return_val_if_fail (&trainer->prop != NULL, FALSE);
+
+ required_sample =
+ trainer->prop.num_training_samples + trainer->prop.num_validation_samples;
+ if (trainer->total_push_data_cnt % required_sample != 0)
+ return FALSE;
+
+ gst_tensor_trainer_wait_for_epoch_completion (trainer);
+
+ return TRUE;
+}
+
+/**
* @brief Chain function, this function does the actual processing.
*/
static GstFlowReturn
GstTensorMetaInfo in_meta[NNS_TENSOR_SIZE_LIMIT];
GstTensorMetaInfo out_meta[NNS_TENSOR_SIZE_LIMIT];
+ double model_stats[MODEL_STATS_SIZE] =
+ { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
+ void *ptr;
+
trainer = GST_TENSOR_TRAINER (parent);
mem_blocks = gst_buffer_n_memory (inbuf);
for (i = 0; i < mem_blocks; i++) {
in_mem[i] = gst_buffer_peek_memory (inbuf, i);
if (!gst_memory_map (in_mem[i], &in_info[i], GST_MAP_READ)) {
- GST_ERROR_OBJECT (trainer, "Could not map in_mem[%d] GstMemory", i);
+ GST_ERROR_OBJECT (trainer, "Could not map in_mem[%u] GstMemory", i);
goto error;
}
in_flexible = gst_tensor_pad_caps_is_flexible (sinkpad);
/* Prepare tensor to push */
/* Check number of input tensors */
- GST_DEBUG_OBJECT (trainer, "num_tensors: %d",
+ GST_DEBUG_OBJECT (trainer, "num_tensors: %u",
trainer->prop.input_meta.num_tensors);
if (mem_blocks != trainer->prop.input_meta.num_tensors) {
- GST_ERROR_OBJECT (trainer, "Invalid memory blocks(%d),"
- "number of input tensors may be (%d)", mem_blocks,
+ GST_ERROR_OBJECT (trainer, "Invalid memory blocks(%u),"
+ "number of input tensors may be (%u)", mem_blocks,
trainer->prop.input_meta.num_tensors);
goto error;
}
}
/* Copy to data pointer */
push_tensors[i] = in_tensors[i];
- GST_INFO ("in_tensors[%d].size= %zd", i, in_tensors[i].size);
- GST_INFO ("in_tensors[%d].data: %p", i, in_tensors[i].data);
- GST_INFO ("push_tensors[%d].size= %zd", i, push_tensors[i].size);
- GST_INFO ("push_tensors[%d].data: %p", i, push_tensors[i].data);
+ GST_INFO ("in_tensors[%u].size= %zd", i, in_tensors[i].size);
+ GST_INFO ("in_tensors[%u].data: %p", i, in_tensors[i].data);
+ GST_INFO ("push_tensors[%u].size= %zd", i, push_tensors[i].size);
+ GST_INFO ("push_tensors[%u].data: %p", i, push_tensors[i].data);
}
ret =
Scheduling with subplugin does not work.
*/
if (trainer->total_push_data_cnt == 1
- || trainer->total_push_data_cnt ==
- trainer->prop.num_training_samples +
- trainer->prop.num_validation_samples) {
+ || gst_tensor_trainer_epochs_is_complete (trainer)) {
/* Prepare output tensor */
for (i = 0; i < trainer->output_meta.num_tensors; i++) {
}
if (!gst_memory_map (out_mem[i], &out_info[i], GST_MAP_WRITE)) {
- GST_ERROR_OBJECT (trainer, "Could not map in_mem[%d] GstMemory", i);
+ GST_ERROR_OBJECT (trainer, "Could not map in_mem[%u] GstMemory", i);
goto error;
}
goto error;
}
}
-#if 0
- /** @todo Need to updatd out_tensors */
- /* get loss, accuracy, val_loss, val_accuracy */
- double data[4] = { 0, 0, 0, 0 };
+
+ ret =
+ trainer->fw->getStatus (trainer->fw, &trainer->prop,
+ trainer->privateData);
+ if (ret < 0) {
+ GST_ERROR_OBJECT (trainer, "Failed to Get status from sub-plugin.(%s).",
+ trainer->fw_name);
+ return GST_FLOW_ERROR;
+ }
+ /* If the value is invalid, it is already set by -INFINITY. */
+ if (trainer->prop.training_loss > 0)
+ model_stats[TRAINING_LOSS] = trainer->prop.training_loss;
+ if (trainer->prop.training_accuracy > 0)
+ model_stats[TRAINING_ACCURACY] = trainer->prop.training_accuracy;
+ if (trainer->prop.validation_loss > 0)
+ model_stats[VALIDATION_LOSS] = trainer->prop.validation_loss;
+ if (trainer->prop.validation_accuracy > 0)
+ model_stats[VALIDATION_ACCURACY] = trainer->prop.validation_accuracy;
+
+ GST_DEBUG_OBJECT (trainer,
+ "#%u/%u epochs [training_loss: %f, training_accuracy: %f, validation_loss: %f, validation_accuracy: %f]",
+ trainer->prop.epoch_count, trainer->prop.num_epochs,
+ model_stats[TRAINING_LOSS], model_stats[TRAINING_ACCURACY],
+ model_stats[VALIDATION_LOSS], model_stats[VALIDATION_ACCURACY]);
+
+ /* updatd out_tensors */
+ /* write training loss, training accuracy, validation loss, validation accuracy */
ptr = out_info[i].data;
- memcpy (ptr, data, sizeof (data));
-#endif
+ memcpy (ptr, model_stats, sizeof (model_stats));
}
/* Free out info */
}
/**
+ * @brief Wait for training completion
+ */
+static gboolean
+gst_tensor_trainer_wait_for_training_completion (GstTensorTrainer * trainer)
+{
+ g_return_val_if_fail (trainer != NULL, FALSE);
+
+ g_mutex_lock (&trainer->training_completion_lock);
+ if (trainer->is_training_complete) {
+ /* It's already completed */
+ trainer->is_training_complete = FALSE;
+ g_mutex_unlock (&trainer->training_completion_lock);
+ return TRUE;
+ }
+
+ GST_INFO_OBJECT (trainer,
+ "got GST_EVENT_EOS event but training is not completed, state is %d",
+ GST_STATE (trainer));
+
+ GST_INFO_OBJECT (trainer, "wait for training_completion_cond signal");
+ g_cond_wait (&trainer->training_completion_cond,
+ &trainer->training_completion_lock);
+ trainer->is_training_complete = FALSE;
+ g_mutex_unlock (&trainer->training_completion_lock);
+
+ return TRUE;
+}
+
+/**
* @brief Event handler for sink pad of tensor_trainer
*/
static gboolean
GstEvent * event)
{
GstTensorTrainer *trainer;
- GstTensorTrainerFrameworkInfo info;
trainer = GST_TENSOR_TRAINER (parent);
GST_DEBUG_OBJECT (trainer, "Received %s event: %" GST_PTR_FORMAT,
switch (GST_EVENT_TYPE (event)) {
case GST_EVENT_EOS:
- trainer->fw->getFrameworkInfo (trainer->fw, NULL, trainer->privateData,
- &info);
- if (!info.is_training_complete) {
- GST_INFO_OBJECT (trainer,
- "got GST_EVENT_EOS event but training is not completed, state is %d",
- GST_STATE (trainer));
- g_mutex_lock (&trainer->trainer_lock);
- GST_INFO_OBJECT (trainer, "wait for training_complete_cond signal");
- g_cond_wait (&trainer->training_complete_cond, &trainer->trainer_lock);
- g_mutex_unlock (&trainer->trainer_lock);
- }
+ if (gst_tensor_trainer_wait_for_training_completion (trainer))
+ GST_DEBUG_OBJECT (trainer, "training is completed in sub-plugin[%s]",
+ trainer->fw_name);
break;
case GST_EVENT_FLUSH_START:
GST_INFO_OBJECT (trainer, "get GST_EVENT_FLUSH_START event");
}
/**
+ * @brief Create a event notifier
+ */
+static void
+gst_tensor_trainer_create_event_notifier (GstTensorTrainer * trainer)
+{
+ g_return_if_fail (trainer != NULL);
+ g_return_if_fail (trainer->fw != NULL);
+
+ trainer->notifier.notifier = (void *) trainer;
+}
+
+/**
* @brief Train model
*/
static void
g_return_if_fail (trainer->fw->start != NULL);
GST_DEBUG_OBJECT (trainer, "Start training model");
- ret = trainer->fw->start (trainer->fw, &trainer->prop, trainer->privateData);
+ ret =
+ trainer->fw->start (trainer->fw, &trainer->prop, &trainer->notifier,
+ trainer->privateData);
if (ret != 0) {
GST_ERROR_OBJECT (trainer, "model training is failed");
}
return unregister_subplugin (NNS_SUBPLUGIN_TRAINER, name);
}
+
+/**
+ * @brief Trainer's sub-plugin may call this to send event.
+ * @param[in] notifier event notifier, sub-plugin must send events with this.
+ * @param[in] type event type
+ */
+void
+nnstreamer_trainer_notify_event (GstTensorTrainerEventNotifier * notifier,
+ GstTensorTrainerEventType type, void *data)
+{
+ GstTensorTrainer *trainer;
+ g_return_if_fail (notifier != NULL);
+ g_return_if_fail (type < TRAINER_EVENT_UNKNOWN || type > 0);
+ UNUSED (data);
+
+ trainer = (GstTensorTrainer *) notifier->notifier;
+ g_return_if_fail (GST_IS_TENSOR_TRAINER (trainer));
+
+ GST_DEBUG ("Received GstTensorTrainerEvent(%d)", type);
+
+ switch (type) {
+ case TRAINER_EVENT_EPOCH_COMPLETION:
+ g_mutex_lock (&trainer->epoch_completion_lock);
+ trainer->is_epoch_complete = TRUE;
+ GST_DEBUG ("send epoch_completion_cond signal");
+ g_cond_signal (&trainer->epoch_completion_cond);
+ g_mutex_unlock (&trainer->epoch_completion_lock);
+ break;
+ case TRAINER_EVENT_TRAINING_COMPLETION:
+ g_mutex_lock (&trainer->training_completion_lock);
+ trainer->is_training_complete = TRUE;
+ GST_DEBUG ("send training_completion_cond signal");
+ g_cond_signal (&trainer->training_completion_cond);
+ g_mutex_unlock (&trainer->training_completion_lock);
+ break;
+ default:
+ break;
+ }
+}