[tensor_trainer] Add load_model_path to GstTensorTrainerProperties
authorhyunil park <hyunil46.park@samsung.com>
Thu, 27 Jul 2023 06:54:04 +0000 (15:54 +0900)
committerjaeyun-jung <39614140+jaeyun-jung@users.noreply.github.com>
Thu, 3 Aug 2023 04:44:23 +0000 (13:44 +0900)
- Add model_load_path to load an existing model to use for training a new model
- Add model_load_path property to tensor_trainer

Signed-off-by: hyunil park <hyunil46.park@samsung.com>
gst/nnstreamer/elements/gsttensor_trainer.c
gst/nnstreamer/elements/gsttensor_trainer.h
gst/nnstreamer/include/nnstreamer_plugin_api_trainer.h
tests/nnstreamer_trainer/unittest_trainer.cc

index f6e29fd..0879c53 100644 (file)
@@ -95,6 +95,7 @@ enum
   PROP_FRAMEWORK,
   PROP_MODEL_CONFIG,
   PROP_MODEL_SAVE_PATH,
+  PROP_MODEL_LOAD_PATH,
   PROP_NUM_INPUTS,              /* number of input list */
   PROP_NUM_LABELS,              /* number of label list */
   PROP_NUM_TRAINING_SAMPLES,    /* number of training data */
@@ -128,6 +129,8 @@ static void gst_tensor_trainer_set_prop_model_config_file_path (GstTensorTrainer
     * trainer, const GValue * value);
 static void gst_tensor_trainer_set_model_save_path (GstTensorTrainer * trainer,
     const GValue * value);
+static void gst_tensor_trainer_set_model_load_path (GstTensorTrainer * trainer,
+    const GValue * value);
 static gboolean gst_tensor_trainer_find_framework (GstTensorTrainer * trainer,
     const char *name);
 static gboolean gst_tensor_trainer_create_framework (GstTensorTrainer *
@@ -169,14 +172,14 @@ gst_tensor_trainer_class_init (GstTensorTrainerClass * klass)
   /* Install properties for tensor_trainer */
   g_object_class_install_property (gobject_class, PROP_FRAMEWORK,
       g_param_spec_string ("framework", "Framework",
-          "Neural network framework to be used for model training",
+          "(not nullable) Neural network framework to be used for model training, ",
           DEFAULT_STR_PROP_VALUE,
           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
           G_PARAM_STATIC_STRINGS));
 
   g_object_class_install_property (gobject_class, PROP_MODEL_CONFIG,
       g_param_spec_string ("model-config", "Model configuration file path",
-          "Model configuration file is used to configure the model "
+          "(not nullable) Model configuration file is used to configure the model "
           "to be trained in neural network framework, set the file path",
           DEFAULT_STR_PROP_VALUE,
           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
@@ -184,12 +187,19 @@ gst_tensor_trainer_class_init (GstTensorTrainerClass * klass)
 
   g_object_class_install_property (gobject_class, PROP_MODEL_SAVE_PATH,
       g_param_spec_string ("model-save-path", "Model save path",
-          "Path to save the trained model in framework, if model-config "
+          "(not nullable) Path to save the trained model in framework, if model-config "
           "contains information about the save file, it is ignored",
           DEFAULT_STR_PROP_VALUE,
           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
           G_PARAM_STATIC_STRINGS));
 
+  g_object_class_install_property (gobject_class, PROP_MODEL_LOAD_PATH,
+      g_param_spec_string ("model-load-path", "Model load path",
+          "(nullable) Path to load an existing model to use for training a new model",
+          DEFAULT_STR_PROP_VALUE,
+          G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
+          G_PARAM_STATIC_STRINGS));
+
   g_object_class_install_property (gobject_class, PROP_NUM_INPUTS,
       g_param_spec_uint ("num-inputs", "Number of inputs",
           "An input in a tensor can have one or more features data,"
@@ -268,15 +278,16 @@ gst_tensor_trainer_init (GstTensorTrainer * trainer)
 
   /** init properties */
   trainer->fw_name = g_strdup (DEFAULT_STR_PROP_VALUE);
-  trainer->model_config = g_strdup (DEFAULT_STR_PROP_VALUE);
-  trainer->model_save_path = g_strdup (DEFAULT_STR_PROP_VALUE);
-  trainer->output_dimensions = g_strdup (DEFAULT_STR_PROP_VALUE);
-  trainer->output_type = g_strdup (DEFAULT_STR_PROP_VALUE);
+  trainer->prop.model_config = g_strdup (DEFAULT_STR_PROP_VALUE);
+  trainer->prop.model_save_path = g_strdup (DEFAULT_STR_PROP_VALUE);
+  trainer->prop.model_load_path = NULL;
   trainer->prop.num_inputs = DEFAULT_PROP_INPUT_LIST;
   trainer->prop.num_labels = DEFAULT_PROP_LABEL_LIST;
   trainer->prop.num_training_samples = DEFAULT_PROP_TRAIN_SAMPLES;
   trainer->prop.num_validation_samples = DEFAULT_PROP_VALID_SAMPLES;
   trainer->prop.num_epochs = DEFAULT_PROP_EPOCHS;
+  trainer->output_dimensions = g_strdup (DEFAULT_STR_PROP_VALUE);
+  trainer->output_type = g_strdup (DEFAULT_STR_PROP_VALUE);
 
   trainer->fw = NULL;
   trainer->fw_created = FALSE;
@@ -310,8 +321,9 @@ gst_tensor_trainer_finalize (GObject * object)
   trainer = GST_TENSOR_TRAINER (object);
 
   g_free (trainer->fw_name);
-  g_free (trainer->model_config);
-  g_free (trainer->model_save_path);
+  g_free ((char *) trainer->prop.model_config);
+  g_free ((char *) trainer->prop.model_save_path);
+  g_free ((char *) trainer->prop.model_load_path);
   g_free (trainer->output_dimensions);
   g_free (trainer->output_type);
 
@@ -351,6 +363,9 @@ gst_tensor_trainer_set_property (GObject * object, guint prop_id,
     case PROP_MODEL_SAVE_PATH:
       gst_tensor_trainer_set_model_save_path (trainer, value);
       break;
+    case PROP_MODEL_LOAD_PATH:
+      gst_tensor_trainer_set_model_load_path (trainer, value);
+      break;
     case PROP_NUM_INPUTS:
       trainer->prop.num_inputs = g_value_get_uint (value);
       break;
@@ -388,10 +403,13 @@ gst_tensor_trainer_get_property (GObject * object, guint prop_id,
       g_value_set_string (value, trainer->fw_name);
       break;
     case PROP_MODEL_CONFIG:
-      g_value_set_string (value, trainer->model_config);
+      g_value_set_string (value, trainer->prop.model_config);
       break;
     case PROP_MODEL_SAVE_PATH:
-      g_value_set_string (value, trainer->model_save_path);
+      g_value_set_string (value, trainer->prop.model_save_path);
+      break;
+    case PROP_MODEL_LOAD_PATH:
+      g_value_set_string (value, trainer->prop.model_load_path);
       break;
     case PROP_NUM_INPUTS:
       g_value_set_uint (value, trainer->prop.num_inputs);
@@ -423,7 +441,11 @@ gst_tensor_trainer_check_invalid_param (GstTensorTrainer * trainer)
   g_return_val_if_fail (trainer != NULL, FALSE);
 
   /* Parameters that can be retrieved from caps will be removed */
-  if (!trainer->fw_name || !trainer->model_config || !trainer->model_save_path
+  if (!trainer->fw_name
+      || (g_ascii_strcasecmp (trainer->prop.model_config,
+              DEFAULT_STR_PROP_VALUE) == 0)
+      || (g_ascii_strcasecmp (trainer->prop.model_save_path,
+              DEFAULT_STR_PROP_VALUE) == 0)
       || trainer->prop.num_epochs <= 0 || trainer->prop.num_inputs <= 0
       || trainer->prop.num_labels <= 0) {
     GST_ERROR_OBJECT (trainer, "Check for invalid param value");
@@ -1020,9 +1042,8 @@ static void
 gst_tensor_trainer_set_prop_model_config_file_path (GstTensorTrainer *
     trainer, const GValue * value)
 {
-  g_free (trainer->model_config);
-  trainer->model_config = g_value_dup_string (value);
-  trainer->prop.model_config = trainer->model_config;
+  g_free ((char *) trainer->prop.model_config);
+  trainer->prop.model_config = g_value_dup_string (value);
   GST_INFO_OBJECT (trainer, "Model configuration file path: %s",
       trainer->prop.model_config);
 }
@@ -1034,14 +1055,26 @@ static void
 gst_tensor_trainer_set_model_save_path (GstTensorTrainer * trainer,
     const GValue * value)
 {
-  g_free (trainer->model_save_path);
-  trainer->model_save_path = g_value_dup_string (value);
-  trainer->prop.model_save_path = trainer->model_save_path;
+  g_free ((char *) trainer->prop.model_save_path);
+  trainer->prop.model_save_path = g_value_dup_string (value);
   GST_INFO_OBJECT (trainer, "File path to save the model: %s",
       trainer->prop.model_save_path);
 }
 
 /**
+ * @brief Handle "PROP_MODEL_LOAD_PATH" for set-property
+ */
+static void
+gst_tensor_trainer_set_model_load_path (GstTensorTrainer * trainer,
+    const GValue * value)
+{
+  g_free ((char *) trainer->prop.model_load_path);
+  trainer->prop.model_load_path = g_value_dup_string (value);
+  GST_INFO_OBJECT (trainer, "File path to load the model: %s",
+      trainer->prop.model_load_path);
+}
+
+/**
  * @brief Find Trainer sub-plugin with the name.
  */
 static gboolean
index 9c43233..33777b1 100644 (file)
@@ -48,8 +48,6 @@ struct _GstTensorTrainer
   GstPad *srcpad;
 
   gchar *fw_name;
-  gchar *model_config;
-  gchar *model_save_path;
   gchar *output_dimensions;
   gchar *output_type;
 
index de43584..d48f712 100644 (file)
@@ -32,7 +32,8 @@ typedef struct _GstTensorTrainerProperties
 {
   GstTensorsInfo input_meta;    /**< configured input tensor info */
   const char *model_config;    /**< The configuration file path for creating model */
-  const char *model_save_path;    /**< The file path to save the model */
+  const char *model_save_path;    /**< The file path to save the new model */
+  const char *model_load_path;    /**< The file path to load an existing model to use for training a new model */
   unsigned int num_inputs;    /**< The number of input lists, the input is where framework receive the features to train the model, num_inputs indicates how many inputs there are. */
   unsigned int num_labels;    /**< The number of label lists, the label is where framework receive the class to train the model, num_labels indicates how many labels there are. */
   unsigned int num_training_samples;    /**< The number of training sample used to train the model. */
index 4564df3..206ab6b 100644 (file)
@@ -73,7 +73,7 @@ TEST (tensor_trainer, SetParams)
       "gst-launch-1.0 datareposrc location=%s json=%s "
       "start-sample-index=3 stop-sample-index=202 tensors-sequence=0,1 epochs=1 ! "
       "tensor_trainer name=tensor_trainer framework=nntrainer model-config=%s "
-      "model-save-path=model.bin num-inputs=1 num-labels=1 "
+      "model-save-path=new_model.bin model-load-path=old_model.bin num-inputs=1 num-labels=1 "
       "num-training-samples=100 num-validation-samples=100 epochs=1 ! "
       "tensor_sink",
       file_path, json_path, model_config_path);
@@ -89,9 +89,22 @@ TEST (tensor_trainer, SetParams)
 
   g_object_get (tensor_trainer, "model-config", &get_str, NULL);
   EXPECT_STREQ (get_str, model_config_path);
+  g_free (get_str);
 
   g_object_get (tensor_trainer, "model-save-path", &get_str, NULL);
-  EXPECT_STREQ (get_str, "model.bin");
+  EXPECT_STREQ (get_str, "new_model.bin");
+  g_free (get_str);
+
+  g_object_get (tensor_trainer, "model-load-path", &get_str, NULL);
+  EXPECT_STREQ (get_str, "old_model.bin");
+  g_free (get_str);
+
+  /* set nullable param */
+  g_object_set (GST_OBJECT (tensor_trainer), "model-load-path", NULL, NULL);
+
+  g_object_get (tensor_trainer, "model-load-path", &get_str, NULL);
+  EXPECT_STREQ (get_str, NULL);
+  g_free (get_str);
 
   g_object_get (tensor_trainer, "num-inputs", &get_value, NULL);
   ASSERT_EQ (get_value, 1U);