trainer->is_training_complete = FALSE;
trainer->is_epoch_complete = FALSE;
trainer->cur_epoch_data_cnt = 0;
+ trainer->required_sample = 0;
gst_tensors_config_init (&trainer->in_config);
gst_tensors_config_init (&trainer->out_config);
g_cond_clear (&trainer->epoch_completion_cond);
g_mutex_clear (&trainer->epoch_completion_lock);
+ if (trainer->dummy_data_thread) {
+ g_thread_join (trainer->dummy_data_thread);
+ trainer->dummy_data_thread = NULL;
+ }
+
if (trainer->fw_created && trainer->fw) {
trainer->fw->destroy (trainer->fw, &trainer->prop, &trainer->privateData);
}
}
/**
+ * @brief Dummy data generation thread
+ */
+static gpointer
+gst_tensor_trainer_dummy_data_generation_func (GstTensorTrainer * trainer)
+{
+ guint i;
+ gint ret = -1;
+ gpointer dummy_data[NNS_TENSOR_SIZE_LIMIT] = { NULL };
+ g_return_val_if_fail (trainer != NULL, NULL);
+
+ gst_tensor_trainer_stop_model_training (trainer);
+
+ for (i = 0; i < trainer->output_meta.num_tensors; i++) {
+ dummy_data[i] = g_malloc (trainer->push_tensors[i].size);
+ memset (dummy_data[i], 1, trainer->push_tensors[i].size);
+ trainer->push_tensors[i].data = dummy_data[i];
+ }
+
+ do {
+ GST_INFO_OBJECT (trainer, "cur_epoch_data_cnt=%u",
+ trainer->cur_epoch_data_cnt);
+ GST_INFO_OBJECT (trainer, "num_tensors=%d",
+ trainer->prop.input_meta.num_tensors);
+
+ ret =
+ trainer->fw->push_data (trainer->fw, &trainer->prop,
+ trainer->privateData, trainer->push_tensors);
+
+ if (ret < 0) {
+ GST_ERROR_OBJECT (trainer, "Failed to push dummy data");
+ } else {
+ trainer->cur_epoch_data_cnt++;
+ }
+ } while (trainer->required_sample > trainer->cur_epoch_data_cnt);
+
+ for (i = 0; i < trainer->output_meta.num_tensors; i++)
+ g_free (dummy_data[i]);
+
+ return NULL;
+}
+
+
+/**
* @brief Change state of tensor_trainsink.
*/
static GstStateChangeReturn
switch (transition) {
case GST_STATE_CHANGE_PLAYING_TO_PAUSED:
GST_INFO_OBJECT (trainer, "PLAYING_TO_PAUSED");
- GST_INFO_OBJECT (trainer, "cur_epoch_data_cnt=%u",
- trainer->cur_epoch_data_cnt);
+ /* need to generate dummy data */
+ if (!trainer->is_training_complete) {
+ if (!g_strcmp0 (trainer->fw_name, "nntrainer")) {
+ GST_INFO_OBJECT (trainer, "cur_epoch_data_cnt=%u",
+ trainer->cur_epoch_data_cnt);
+ trainer->dummy_data_thread =
+ g_thread_new ("dumy_data_generation_func",
+ (GThreadFunc) gst_tensor_trainer_dummy_data_generation_func,
+ trainer);
+ }
+ }
break;
case GST_STATE_CHANGE_PAUSED_TO_READY:
static gboolean
gst_tensor_trainer_epochs_is_complete (GstTensorTrainer * trainer)
{
- guint required_sample;
-
g_return_val_if_fail (trainer != NULL, FALSE);
g_return_val_if_fail (trainer->fw != NULL, FALSE);
g_return_val_if_fail (&trainer->prop != NULL, FALSE);
- required_sample =
+ trainer->required_sample =
trainer->prop.num_training_samples + trainer->prop.num_validation_samples;
- if (trainer->cur_epoch_data_cnt != required_sample)
+ if (trainer->cur_epoch_data_cnt != trainer->required_sample)
return FALSE;
gst_tensor_trainer_wait_for_epoch_completion (trainer);
GstMemory *out_mem[NNS_TENSOR_SIZE_LIMIT] = { 0, };
GstMapInfo out_info[NNS_TENSOR_SIZE_LIMIT];
GstTensorMemory in_tensors[NNS_TENSOR_SIZE_LIMIT];
- GstTensorMemory push_tensors[NNS_TENSOR_SIZE_LIMIT];
GstTensorMemory out_tensors[NNS_TENSOR_SIZE_LIMIT];
GstTensorMetaInfo in_meta[NNS_TENSOR_SIZE_LIMIT];
GstTensorMetaInfo out_meta[NNS_TENSOR_SIZE_LIMIT];
goto error;
}
/* Copy to data pointer */
- push_tensors[i] = in_tensors[i];
+ trainer->push_tensors[i] = in_tensors[i];
GST_INFO ("in_tensors[%u].size= %zd", i, in_tensors[i].size);
GST_INFO ("in_tensors[%u].data: %p", i, in_tensors[i].data);
- GST_INFO ("push_tensors[%u].size= %zd", i, push_tensors[i].size);
- GST_INFO ("push_tensors[%u].data: %p", i, push_tensors[i].data);
+ GST_INFO ("push_tensors[%u].size= %zd", i, trainer->push_tensors[i].size);
+ GST_INFO ("push_tensors[%u].data: %p", i, trainer->push_tensors[i].data);
}
ret =
trainer->fw->push_data (trainer->fw, &trainer->prop, trainer->privateData,
- push_tensors);
+ trainer->push_tensors);
trainer->cur_epoch_data_cnt++;
/* Free in info */
g_return_if_fail (trainer != NULL);
g_return_if_fail (trainer->fw != NULL);
g_return_if_fail (trainer->fw->stop != NULL);
- g_return_if_fail (trainer->ready_to_complete_training);
-
GST_DEBUG_OBJECT (trainer, "Stop model training");
ret = trainer->fw->stop (trainer->fw, &trainer->prop, &trainer->privateData);
if (ret != 0) {