PROP_FRAMEWORK,
PROP_MODEL_CONFIG,
PROP_MODEL_SAVE_PATH,
+ PROP_MODEL_LOAD_PATH,
PROP_NUM_INPUTS, /* number of input list */
PROP_NUM_LABELS, /* number of label list */
PROP_NUM_TRAINING_SAMPLES, /* number of training data */
* trainer, const GValue * value);
static void gst_tensor_trainer_set_model_save_path (GstTensorTrainer * trainer,
const GValue * value);
+static void gst_tensor_trainer_set_model_load_path (GstTensorTrainer * trainer,
+ const GValue * value);
static gboolean gst_tensor_trainer_find_framework (GstTensorTrainer * trainer,
const char *name);
static gboolean gst_tensor_trainer_create_framework (GstTensorTrainer *
/* Install properties for tensor_trainer */
g_object_class_install_property (gobject_class, PROP_FRAMEWORK,
g_param_spec_string ("framework", "Framework",
- "Neural network framework to be used for model training",
+ "(not nullable) Neural network framework to be used for model training, ",
DEFAULT_STR_PROP_VALUE,
G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
G_PARAM_STATIC_STRINGS));
g_object_class_install_property (gobject_class, PROP_MODEL_CONFIG,
g_param_spec_string ("model-config", "Model configuration file path",
- "Model configuration file is used to configure the model "
+ "(not nullable) Model configuration file is used to configure the model "
"to be trained in neural network framework, set the file path",
DEFAULT_STR_PROP_VALUE,
G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
g_object_class_install_property (gobject_class, PROP_MODEL_SAVE_PATH,
g_param_spec_string ("model-save-path", "Model save path",
- "Path to save the trained model in framework, if model-config "
+ "(not nullable) Path to save the trained model in framework, if model-config "
"contains information about the save file, it is ignored",
DEFAULT_STR_PROP_VALUE,
G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
G_PARAM_STATIC_STRINGS));
+ g_object_class_install_property (gobject_class, PROP_MODEL_LOAD_PATH,
+ g_param_spec_string ("model-load-path", "Model load path",
+ "(nullable) Path to load an existing model to use for training a new model",
+ DEFAULT_STR_PROP_VALUE,
+ G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
+ G_PARAM_STATIC_STRINGS));
+
g_object_class_install_property (gobject_class, PROP_NUM_INPUTS,
g_param_spec_uint ("num-inputs", "Number of inputs",
"An input in a tensor can have one or more features data,"
/** init properties */
trainer->fw_name = g_strdup (DEFAULT_STR_PROP_VALUE);
- trainer->model_config = g_strdup (DEFAULT_STR_PROP_VALUE);
- trainer->model_save_path = g_strdup (DEFAULT_STR_PROP_VALUE);
- trainer->output_dimensions = g_strdup (DEFAULT_STR_PROP_VALUE);
- trainer->output_type = g_strdup (DEFAULT_STR_PROP_VALUE);
+ trainer->prop.model_config = g_strdup (DEFAULT_STR_PROP_VALUE);
+ trainer->prop.model_save_path = g_strdup (DEFAULT_STR_PROP_VALUE);
+ trainer->prop.model_load_path = NULL;
trainer->prop.num_inputs = DEFAULT_PROP_INPUT_LIST;
trainer->prop.num_labels = DEFAULT_PROP_LABEL_LIST;
trainer->prop.num_training_samples = DEFAULT_PROP_TRAIN_SAMPLES;
trainer->prop.num_validation_samples = DEFAULT_PROP_VALID_SAMPLES;
trainer->prop.num_epochs = DEFAULT_PROP_EPOCHS;
+ trainer->output_dimensions = g_strdup (DEFAULT_STR_PROP_VALUE);
+ trainer->output_type = g_strdup (DEFAULT_STR_PROP_VALUE);
trainer->fw = NULL;
trainer->fw_created = FALSE;
trainer = GST_TENSOR_TRAINER (object);
g_free (trainer->fw_name);
- g_free (trainer->model_config);
- g_free (trainer->model_save_path);
+ g_free ((char *) trainer->prop.model_config);
+ g_free ((char *) trainer->prop.model_save_path);
+ g_free ((char *) trainer->prop.model_load_path);
g_free (trainer->output_dimensions);
g_free (trainer->output_type);
case PROP_MODEL_SAVE_PATH:
gst_tensor_trainer_set_model_save_path (trainer, value);
break;
+ case PROP_MODEL_LOAD_PATH:
+ gst_tensor_trainer_set_model_load_path (trainer, value);
+ break;
case PROP_NUM_INPUTS:
trainer->prop.num_inputs = g_value_get_uint (value);
break;
g_value_set_string (value, trainer->fw_name);
break;
case PROP_MODEL_CONFIG:
- g_value_set_string (value, trainer->model_config);
+ g_value_set_string (value, trainer->prop.model_config);
break;
case PROP_MODEL_SAVE_PATH:
- g_value_set_string (value, trainer->model_save_path);
+ g_value_set_string (value, trainer->prop.model_save_path);
+ break;
+ case PROP_MODEL_LOAD_PATH:
+ g_value_set_string (value, trainer->prop.model_load_path);
break;
case PROP_NUM_INPUTS:
g_value_set_uint (value, trainer->prop.num_inputs);
g_return_val_if_fail (trainer != NULL, FALSE);
/* Parameters that can be retrieved from caps will be removed */
- if (!trainer->fw_name || !trainer->model_config || !trainer->model_save_path
+ if (!trainer->fw_name
+ || (g_ascii_strcasecmp (trainer->prop.model_config,
+ DEFAULT_STR_PROP_VALUE) == 0)
+ || (g_ascii_strcasecmp (trainer->prop.model_save_path,
+ DEFAULT_STR_PROP_VALUE) == 0)
|| trainer->prop.num_epochs <= 0 || trainer->prop.num_inputs <= 0
|| trainer->prop.num_labels <= 0) {
GST_ERROR_OBJECT (trainer, "Check for invalid param value");
gst_tensor_trainer_set_prop_model_config_file_path (GstTensorTrainer *
trainer, const GValue * value)
{
- g_free (trainer->model_config);
- trainer->model_config = g_value_dup_string (value);
- trainer->prop.model_config = trainer->model_config;
+ g_free ((char *) trainer->prop.model_config);
+ trainer->prop.model_config = g_value_dup_string (value);
GST_INFO_OBJECT (trainer, "Model configuration file path: %s",
trainer->prop.model_config);
}
gst_tensor_trainer_set_model_save_path (GstTensorTrainer * trainer,
const GValue * value)
{
- g_free (trainer->model_save_path);
- trainer->model_save_path = g_value_dup_string (value);
- trainer->prop.model_save_path = trainer->model_save_path;
+ g_free ((char *) trainer->prop.model_save_path);
+ trainer->prop.model_save_path = g_value_dup_string (value);
GST_INFO_OBJECT (trainer, "File path to save the model: %s",
trainer->prop.model_save_path);
}
/**
+ * @brief Handle "PROP_MODEL_LOAD_PATH" for set-property
+ */
+static void
+gst_tensor_trainer_set_model_load_path (GstTensorTrainer * trainer,
+ const GValue * value)
+{
+ g_free ((char *) trainer->prop.model_load_path);
+ trainer->prop.model_load_path = g_value_dup_string (value);
+ GST_INFO_OBJECT (trainer, "File path to load the model: %s",
+ trainer->prop.model_load_path);
+}
+
+/**
* @brief Find Trainer sub-plugin with the name.
*/
static gboolean
"gst-launch-1.0 datareposrc location=%s json=%s "
"start-sample-index=3 stop-sample-index=202 tensors-sequence=0,1 epochs=1 ! "
"tensor_trainer name=tensor_trainer framework=nntrainer model-config=%s "
- "model-save-path=model.bin num-inputs=1 num-labels=1 "
+ "model-save-path=new_model.bin model-load-path=old_model.bin num-inputs=1 num-labels=1 "
"num-training-samples=100 num-validation-samples=100 epochs=1 ! "
"tensor_sink",
file_path, json_path, model_config_path);
g_object_get (tensor_trainer, "model-config", &get_str, NULL);
EXPECT_STREQ (get_str, model_config_path);
+ g_free (get_str);
g_object_get (tensor_trainer, "model-save-path", &get_str, NULL);
- EXPECT_STREQ (get_str, "model.bin");
+ EXPECT_STREQ (get_str, "new_model.bin");
+ g_free (get_str);
+
+ g_object_get (tensor_trainer, "model-load-path", &get_str, NULL);
+ EXPECT_STREQ (get_str, "old_model.bin");
+ g_free (get_str);
+
+ /* set nullable param */
+ g_object_set (GST_OBJECT (tensor_trainer), "model-load-path", NULL, NULL);
+
+ g_object_get (tensor_trainer, "model-load-path", &get_str, NULL);
+ EXPECT_STREQ (get_str, NULL);
+ g_free (get_str);
g_object_get (tensor_trainer, "num-inputs", &get_value, NULL);
ASSERT_EQ (get_value, 1U);