From 8e7589e7d1cad8c0638e7fbb10478eb2f1397c51 Mon Sep 17 00:00:00 2001 From: hyunil park Date: Thu, 27 Jul 2023 15:54:04 +0900 Subject: [PATCH] [tensor_trainer] Add load_model_path to GstTensorTrainerProperties - Add model_load_path to load an existing model to use for training a new model - Add model_load_path property to tensor_trainer Signed-off-by: hyunil park --- gst/nnstreamer/elements/gsttensor_trainer.c | 69 ++++++++++++++++------ gst/nnstreamer/elements/gsttensor_trainer.h | 2 - .../include/nnstreamer_plugin_api_trainer.h | 3 +- tests/nnstreamer_trainer/unittest_trainer.cc | 17 +++++- 4 files changed, 68 insertions(+), 23 deletions(-) diff --git a/gst/nnstreamer/elements/gsttensor_trainer.c b/gst/nnstreamer/elements/gsttensor_trainer.c index f6e29fd..0879c53 100644 --- a/gst/nnstreamer/elements/gsttensor_trainer.c +++ b/gst/nnstreamer/elements/gsttensor_trainer.c @@ -95,6 +95,7 @@ enum PROP_FRAMEWORK, PROP_MODEL_CONFIG, PROP_MODEL_SAVE_PATH, + PROP_MODEL_LOAD_PATH, PROP_NUM_INPUTS, /* number of input list */ PROP_NUM_LABELS, /* number of label list */ PROP_NUM_TRAINING_SAMPLES, /* number of training data */ @@ -128,6 +129,8 @@ static void gst_tensor_trainer_set_prop_model_config_file_path (GstTensorTrainer * trainer, const GValue * value); static void gst_tensor_trainer_set_model_save_path (GstTensorTrainer * trainer, const GValue * value); +static void gst_tensor_trainer_set_model_load_path (GstTensorTrainer * trainer, + const GValue * value); static gboolean gst_tensor_trainer_find_framework (GstTensorTrainer * trainer, const char *name); static gboolean gst_tensor_trainer_create_framework (GstTensorTrainer * @@ -169,14 +172,14 @@ gst_tensor_trainer_class_init (GstTensorTrainerClass * klass) /* Install properties for tensor_trainer */ g_object_class_install_property (gobject_class, PROP_FRAMEWORK, g_param_spec_string ("framework", "Framework", - "Neural network framework to be used for model training", + "(not nullable) Neural network framework to be used for model training, ", DEFAULT_STR_PROP_VALUE, G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY | G_PARAM_STATIC_STRINGS)); g_object_class_install_property (gobject_class, PROP_MODEL_CONFIG, g_param_spec_string ("model-config", "Model configuration file path", - "Model configuration file is used to configure the model " + "(not nullable) Model configuration file is used to configure the model " "to be trained in neural network framework, set the file path", DEFAULT_STR_PROP_VALUE, G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY | @@ -184,12 +187,19 @@ gst_tensor_trainer_class_init (GstTensorTrainerClass * klass) g_object_class_install_property (gobject_class, PROP_MODEL_SAVE_PATH, g_param_spec_string ("model-save-path", "Model save path", - "Path to save the trained model in framework, if model-config " + "(not nullable) Path to save the trained model in framework, if model-config " "contains information about the save file, it is ignored", DEFAULT_STR_PROP_VALUE, G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY | G_PARAM_STATIC_STRINGS)); + g_object_class_install_property (gobject_class, PROP_MODEL_LOAD_PATH, + g_param_spec_string ("model-load-path", "Model load path", + "(nullable) Path to load an existing model to use for training a new model", + DEFAULT_STR_PROP_VALUE, + G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY | + G_PARAM_STATIC_STRINGS)); + g_object_class_install_property (gobject_class, PROP_NUM_INPUTS, g_param_spec_uint ("num-inputs", "Number of inputs", "An input in a tensor can have one or more features data," @@ -268,15 +278,16 @@ gst_tensor_trainer_init (GstTensorTrainer * trainer) /** init properties */ trainer->fw_name = g_strdup (DEFAULT_STR_PROP_VALUE); - trainer->model_config = g_strdup (DEFAULT_STR_PROP_VALUE); - trainer->model_save_path = g_strdup (DEFAULT_STR_PROP_VALUE); - trainer->output_dimensions = g_strdup (DEFAULT_STR_PROP_VALUE); - trainer->output_type = g_strdup (DEFAULT_STR_PROP_VALUE); + trainer->prop.model_config = g_strdup (DEFAULT_STR_PROP_VALUE); + trainer->prop.model_save_path = g_strdup (DEFAULT_STR_PROP_VALUE); + trainer->prop.model_load_path = NULL; trainer->prop.num_inputs = DEFAULT_PROP_INPUT_LIST; trainer->prop.num_labels = DEFAULT_PROP_LABEL_LIST; trainer->prop.num_training_samples = DEFAULT_PROP_TRAIN_SAMPLES; trainer->prop.num_validation_samples = DEFAULT_PROP_VALID_SAMPLES; trainer->prop.num_epochs = DEFAULT_PROP_EPOCHS; + trainer->output_dimensions = g_strdup (DEFAULT_STR_PROP_VALUE); + trainer->output_type = g_strdup (DEFAULT_STR_PROP_VALUE); trainer->fw = NULL; trainer->fw_created = FALSE; @@ -310,8 +321,9 @@ gst_tensor_trainer_finalize (GObject * object) trainer = GST_TENSOR_TRAINER (object); g_free (trainer->fw_name); - g_free (trainer->model_config); - g_free (trainer->model_save_path); + g_free ((char *) trainer->prop.model_config); + g_free ((char *) trainer->prop.model_save_path); + g_free ((char *) trainer->prop.model_load_path); g_free (trainer->output_dimensions); g_free (trainer->output_type); @@ -351,6 +363,9 @@ gst_tensor_trainer_set_property (GObject * object, guint prop_id, case PROP_MODEL_SAVE_PATH: gst_tensor_trainer_set_model_save_path (trainer, value); break; + case PROP_MODEL_LOAD_PATH: + gst_tensor_trainer_set_model_load_path (trainer, value); + break; case PROP_NUM_INPUTS: trainer->prop.num_inputs = g_value_get_uint (value); break; @@ -388,10 +403,13 @@ gst_tensor_trainer_get_property (GObject * object, guint prop_id, g_value_set_string (value, trainer->fw_name); break; case PROP_MODEL_CONFIG: - g_value_set_string (value, trainer->model_config); + g_value_set_string (value, trainer->prop.model_config); break; case PROP_MODEL_SAVE_PATH: - g_value_set_string (value, trainer->model_save_path); + g_value_set_string (value, trainer->prop.model_save_path); + break; + case PROP_MODEL_LOAD_PATH: + g_value_set_string (value, trainer->prop.model_load_path); break; case PROP_NUM_INPUTS: g_value_set_uint (value, trainer->prop.num_inputs); @@ -423,7 +441,11 @@ gst_tensor_trainer_check_invalid_param (GstTensorTrainer * trainer) g_return_val_if_fail (trainer != NULL, FALSE); /* Parameters that can be retrieved from caps will be removed */ - if (!trainer->fw_name || !trainer->model_config || !trainer->model_save_path + if (!trainer->fw_name + || (g_ascii_strcasecmp (trainer->prop.model_config, + DEFAULT_STR_PROP_VALUE) == 0) + || (g_ascii_strcasecmp (trainer->prop.model_save_path, + DEFAULT_STR_PROP_VALUE) == 0) || trainer->prop.num_epochs <= 0 || trainer->prop.num_inputs <= 0 || trainer->prop.num_labels <= 0) { GST_ERROR_OBJECT (trainer, "Check for invalid param value"); @@ -1020,9 +1042,8 @@ static void gst_tensor_trainer_set_prop_model_config_file_path (GstTensorTrainer * trainer, const GValue * value) { - g_free (trainer->model_config); - trainer->model_config = g_value_dup_string (value); - trainer->prop.model_config = trainer->model_config; + g_free ((char *) trainer->prop.model_config); + trainer->prop.model_config = g_value_dup_string (value); GST_INFO_OBJECT (trainer, "Model configuration file path: %s", trainer->prop.model_config); } @@ -1034,14 +1055,26 @@ static void gst_tensor_trainer_set_model_save_path (GstTensorTrainer * trainer, const GValue * value) { - g_free (trainer->model_save_path); - trainer->model_save_path = g_value_dup_string (value); - trainer->prop.model_save_path = trainer->model_save_path; + g_free ((char *) trainer->prop.model_save_path); + trainer->prop.model_save_path = g_value_dup_string (value); GST_INFO_OBJECT (trainer, "File path to save the model: %s", trainer->prop.model_save_path); } /** + * @brief Handle "PROP_MODEL_LOAD_PATH" for set-property + */ +static void +gst_tensor_trainer_set_model_load_path (GstTensorTrainer * trainer, + const GValue * value) +{ + g_free ((char *) trainer->prop.model_load_path); + trainer->prop.model_load_path = g_value_dup_string (value); + GST_INFO_OBJECT (trainer, "File path to load the model: %s", + trainer->prop.model_load_path); +} + +/** * @brief Find Trainer sub-plugin with the name. */ static gboolean diff --git a/gst/nnstreamer/elements/gsttensor_trainer.h b/gst/nnstreamer/elements/gsttensor_trainer.h index 9c43233..33777b1 100644 --- a/gst/nnstreamer/elements/gsttensor_trainer.h +++ b/gst/nnstreamer/elements/gsttensor_trainer.h @@ -48,8 +48,6 @@ struct _GstTensorTrainer GstPad *srcpad; gchar *fw_name; - gchar *model_config; - gchar *model_save_path; gchar *output_dimensions; gchar *output_type; diff --git a/gst/nnstreamer/include/nnstreamer_plugin_api_trainer.h b/gst/nnstreamer/include/nnstreamer_plugin_api_trainer.h index de43584..d48f712 100644 --- a/gst/nnstreamer/include/nnstreamer_plugin_api_trainer.h +++ b/gst/nnstreamer/include/nnstreamer_plugin_api_trainer.h @@ -32,7 +32,8 @@ typedef struct _GstTensorTrainerProperties { GstTensorsInfo input_meta; /**< configured input tensor info */ const char *model_config; /**< The configuration file path for creating model */ - const char *model_save_path; /**< The file path to save the model */ + const char *model_save_path; /**< The file path to save the new model */ + const char *model_load_path; /**< The file path to load an existing model to use for training a new model */ unsigned int num_inputs; /**< The number of input lists, the input is where framework receive the features to train the model, num_inputs indicates how many inputs there are. */ unsigned int num_labels; /**< The number of label lists, the label is where framework receive the class to train the model, num_labels indicates how many labels there are. */ unsigned int num_training_samples; /**< The number of training sample used to train the model. */ diff --git a/tests/nnstreamer_trainer/unittest_trainer.cc b/tests/nnstreamer_trainer/unittest_trainer.cc index 4564df3..206ab6b 100644 --- a/tests/nnstreamer_trainer/unittest_trainer.cc +++ b/tests/nnstreamer_trainer/unittest_trainer.cc @@ -73,7 +73,7 @@ TEST (tensor_trainer, SetParams) "gst-launch-1.0 datareposrc location=%s json=%s " "start-sample-index=3 stop-sample-index=202 tensors-sequence=0,1 epochs=1 ! " "tensor_trainer name=tensor_trainer framework=nntrainer model-config=%s " - "model-save-path=model.bin num-inputs=1 num-labels=1 " + "model-save-path=new_model.bin model-load-path=old_model.bin num-inputs=1 num-labels=1 " "num-training-samples=100 num-validation-samples=100 epochs=1 ! " "tensor_sink", file_path, json_path, model_config_path); @@ -89,9 +89,22 @@ TEST (tensor_trainer, SetParams) g_object_get (tensor_trainer, "model-config", &get_str, NULL); EXPECT_STREQ (get_str, model_config_path); + g_free (get_str); g_object_get (tensor_trainer, "model-save-path", &get_str, NULL); - EXPECT_STREQ (get_str, "model.bin"); + EXPECT_STREQ (get_str, "new_model.bin"); + g_free (get_str); + + g_object_get (tensor_trainer, "model-load-path", &get_str, NULL); + EXPECT_STREQ (get_str, "old_model.bin"); + g_free (get_str); + + /* set nullable param */ + g_object_set (GST_OBJECT (tensor_trainer), "model-load-path", NULL, NULL); + + g_object_get (tensor_trainer, "model-load-path", &get_str, NULL); + EXPECT_STREQ (get_str, NULL); + g_free (get_str); g_object_get (tensor_trainer, "num-inputs", &get_value, NULL); ASSERT_EQ (get_value, 1U); -- 2.7.4