[API][tensor_trainer] Change type of structure member variable
authorhyunil park <hyunil46.park@samsung.com>
Mon, 17 Apr 2023 05:42:59 +0000 (14:42 +0900)
committerSangjung Woo <again4you@gmail.com>
Fri, 21 Apr 2023 02:30:27 +0000 (11:30 +0900)
- Change from int64_t to unsigned int
- Bug-fix: When passing values to a sub-plugin in an arm 32bit environment, invalid values are passed.
- Add function to Check invalid param
- Remove default framework
- Change some default value

Signed-off-by: hyunil park <hyunil46.park@samsung.com>
gst/nnstreamer/elements/gsttensor_trainer.c
gst/nnstreamer/include/nnstreamer_plugin_api_trainer.h

index ef5f1d0..9c43624 100644 (file)
@@ -11,7 +11,7 @@
  *
  * ## Example launch line
  * |[
- * gst-launch-1.0 repo_src location=mnist_trainingSet.dat ! \
+ * gst-launch-1.0 datareposrc location=mnist_trainingSet.dat json=mnist.json start-sample-index=3 stop-sample-index=202 epochs=5 ! \
  * other/tensors, format=static, num_tensors=2, framerate=0/1, dimensions=1:1:784:1.1:1:10:1, types=float32.float32 ! \
  * tensor_trainer framework=nntrainer model-config=mnist.ini model-save-path=model.bin input-dim=1:1:784:1,1:1:10:1 \
  * input-type=float32,float32 num-inputs=1 num-labels=1 num-training-samples=100 num-validation-samples=100 epochs=5 ! \
@@ -59,7 +59,6 @@ G_DEFINE_TYPE (GstTensorTrainer, gst_tensor_trainer, GST_TYPE_ELEMENT);
 /**
  * @brief Default framework property value
  */
-#define DEFAULT_PROP_FRAMEWORK "nntrainer"
 #define DEFAULT_PROP_INPUT_LIST 1
 #define DEFAULT_PROP_LABEL_LIST 1
 #define DEFAULT_PROP_TRAIN_SAMPLES 0
@@ -117,12 +116,13 @@ static void gst_tensor_trainer_set_prop_input_dimension (GstTensorTrainer *
     trainer, const GValue * value);
 static void gst_tensor_trainer_set_prop_input_type (GstTensorTrainer * trainer,
     const GValue * value);
-static void gst_tensor_trainer_find_framework (GstTensorTrainer * trainer,
+static gboolean gst_tensor_trainer_find_framework (GstTensorTrainer * trainer,
     const char *name);
-static void gst_tensor_trainer_create_framework (GstTensorTrainer * trainer);
+static gboolean gst_tensor_trainer_create_framework (GstTensorTrainer *
+    trainer);
 static gsize gst_tensor_trainer_get_tensor_size (GstTensorTrainer * trainer,
     guint index, gboolean is_input);
-static void gst_tensor_trainer_create_model (GstTensorTrainer * trainer);
+static gboolean gst_tensor_trainer_create_model (GstTensorTrainer * trainer);
 static void gst_tensor_trainer_train_model (GstTensorTrainer * trainer);
 static void gst_tensor_trainer_output_dimension (GstTensorTrainer * trainer);
 static void gst_tensor_trainer_output_type (GstTensorTrainer * trainer);
@@ -156,7 +156,7 @@ gst_tensor_trainer_class_init (GstTensorTrainerClass * klass)
   g_object_class_install_property (gobject_class, PROP_FRAMEWORK,
       g_param_spec_string ("framework", "Framework",
           "Neural network framework to be used for model training",
-          DEFAULT_PROP_FRAMEWORK,
+          DEFAULT_STR_PROP_VALUE,
           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
           G_PARAM_STATIC_STRINGS));
 
@@ -192,41 +192,41 @@ gst_tensor_trainer_class_init (GstTensorTrainerClass * klass)
 
 
   g_object_class_install_property (gobject_class, PROP_NUM_INPUTS,
-      g_param_spec_int64 ("num-inputs", "Number of inputs",
+      g_param_spec_uint ("num-inputs", "Number of inputs",
           "An input in a tensor can have one or more features data,"
           "set how many inputs are received", 0, NNS_TENSOR_SIZE_LIMIT, 1,
           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
           G_PARAM_STATIC_STRINGS));
 
   g_object_class_install_property (gobject_class, PROP_NUM_LABELS,
-      g_param_spec_int64 ("num-labels", "Number of labels",
+      g_param_spec_uint ("num-labels", "Number of labels",
           "A label in a tensor can have one or more classes data,"
           "set how many labels are received", 0, NNS_TENSOR_SIZE_LIMIT, 1,
           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
           G_PARAM_STATIC_STRINGS));
 
   g_object_class_install_property (gobject_class, PROP_NUM_TRAINING_SAMPLES,
-      g_param_spec_int64 ("num-training-samples", "Number of training samples",
+      g_param_spec_uint ("num-training-samples", "Number of training samples",
           "A sample can consist of multiple inputs and labels in tensors of a gstbuffer"
           ", set how many samples are taken for training model",
-          0, G_MAXINT64, 1,
+          0, G_MAXINT, 0,
           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
           G_PARAM_STATIC_STRINGS));
 
   g_object_class_install_property (gobject_class, PROP_NUM_VALIDATION_SAMPLES,
-      g_param_spec_int64 ("num-validation-samples",
+      g_param_spec_uint ("num-validation-samples",
           "Number of validation samples",
           "A sample can consist of multiple inputs and labels in tensors of a gstbuffer"
           ", set how many samples are taken for validation model",
-          0, G_MAXINT64, 1,
+          0, G_MAXINT, 0,
           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
           G_PARAM_STATIC_STRINGS));
 
   g_object_class_install_property (gobject_class, PROP_EPOCHS,
-      g_param_spec_int64 ("epochs", "Number of epoch",
+      g_param_spec_uint ("epochs", "Number of epoch",
           "Epochs are repetitions of training samples and validation smaples, "
           "number of samples received for model training is "
-          "(num-training-samples+num-validation-samples)*epochs", 0, G_MAXINT64,
+          "(num-training-samples+num-validation-samples)*epochs", 0, G_MAXINT,
           DEFAULT_PROP_EPOCHS,
           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
           G_PARAM_STATIC_STRINGS));
@@ -268,7 +268,7 @@ gst_tensor_trainer_init (GstTensorTrainer * trainer)
   gst_element_add_pad (GST_ELEMENT (trainer), trainer->srcpad);
 
   /** init properties */
-  trainer->fw_name = g_strdup (DEFAULT_PROP_FRAMEWORK);
+  trainer->fw_name = g_strdup (DEFAULT_STR_PROP_VALUE);
   trainer->model_config = g_strdup (DEFAULT_STR_PROP_VALUE);
   trainer->model_save_path = g_strdup (DEFAULT_STR_PROP_VALUE);
   trainer->input_dimensions = g_strdup (DEFAULT_STR_PROP_VALUE);
@@ -352,19 +352,19 @@ gst_tensor_trainer_set_property (GObject * object, guint prop_id,
       gst_tensor_trainer_set_prop_input_type (trainer, value);
       break;
     case PROP_NUM_INPUTS:
-      trainer->prop.num_inputs = g_value_get_int64 (value);
+      trainer->prop.num_inputs = g_value_get_uint (value);
       break;
     case PROP_NUM_LABELS:
-      trainer->prop.num_labels = g_value_get_int64 (value);
+      trainer->prop.num_labels = g_value_get_uint (value);
       break;
     case PROP_NUM_TRAINING_SAMPLES:
-      trainer->prop.num_training_samples = g_value_get_int64 (value);
+      trainer->prop.num_training_samples = g_value_get_uint (value);
       break;
     case PROP_NUM_VALIDATION_SAMPLES:
-      trainer->prop.num_validation_samples = g_value_get_int64 (value);
+      trainer->prop.num_validation_samples = g_value_get_uint (value);
       break;
     case PROP_EPOCHS:
-      trainer->prop.num_epochs = g_value_get_int64 (value);
+      trainer->prop.num_epochs = g_value_get_uint (value);
       break;
     default:
       G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
@@ -400,19 +400,19 @@ gst_tensor_trainer_get_property (GObject * object, guint prop_id,
       g_value_set_string (value, trainer->input_type);
       break;
     case PROP_NUM_INPUTS:
-      g_value_set_int64 (value, trainer->prop.num_inputs);
+      g_value_set_uint (value, trainer->prop.num_inputs);
       break;
     case PROP_NUM_LABELS:
-      g_value_set_int64 (value, trainer->prop.num_labels);
+      g_value_set_uint (value, trainer->prop.num_labels);
       break;
     case PROP_NUM_TRAINING_SAMPLES:
-      g_value_set_int64 (value, trainer->prop.num_training_samples);
+      g_value_set_uint (value, trainer->prop.num_training_samples);
       break;
     case PROP_NUM_VALIDATION_SAMPLES:
-      g_value_set_int64 (value, trainer->prop.num_validation_samples);
+      g_value_set_uint (value, trainer->prop.num_validation_samples);
       break;
     case PROP_EPOCHS:
-      g_value_set_int64 (value, trainer->prop.num_epochs);
+      g_value_set_uint (value, trainer->prop.num_epochs);
       break;
     default:
       G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
@@ -421,6 +421,27 @@ gst_tensor_trainer_get_property (GObject * object, guint prop_id,
 }
 
 /**
+ * @brief Check invalid param
+ */
+static gboolean
+gst_tensor_trainer_check_invalid_param (GstTensorTrainer * trainer)
+{
+  g_return_val_if_fail (trainer != NULL, FALSE);
+
+  /* Parameters that can be retrieved from caps will be removed */
+  if (!trainer->fw_name || !trainer->model_config || !trainer->model_save_path
+      || !trainer->input_dimensions || !trainer->input_type
+      || trainer->prop.num_epochs <= 0 || trainer->prop.num_inputs <= 0
+      || trainer->prop.num_labels <= 0) {
+    GST_ERROR_OBJECT (trainer, "Check for invalid param value");
+
+    return FALSE;
+  }
+
+  return TRUE;
+}
+
+/**
  * @brief Change state of tensor_trainsink.
  */
 static GstStateChangeReturn
@@ -433,11 +454,17 @@ gst_tensor_trainer_change_state (GstElement * element,
   switch (transition) {
     case GST_STATE_CHANGE_NULL_TO_READY:
       GST_INFO_OBJECT (trainer, "NULL_TO_READY");
+
+      if (!gst_tensor_trainer_check_invalid_param (trainer))
+        goto state_change_failed;
+
       break;
 
     case GST_STATE_CHANGE_READY_TO_PAUSED:
       GST_INFO_OBJECT (trainer, "READY_TO_PAUSED");
-      gst_tensor_trainer_create_model (trainer);
+      if (!gst_tensor_trainer_create_model (trainer))
+        goto state_change_failed;
+
       break;
 
     case GST_STATE_CHANGE_PAUSED_TO_PLAYING:
@@ -471,6 +498,11 @@ gst_tensor_trainer_change_state (GstElement * element,
   }
 
   return ret;
+
+state_change_failed:
+  GST_ERROR_OBJECT (trainer, "state change failed");
+
+  return GST_STATE_CHANGE_FAILURE;
 }
 
 /**
@@ -999,42 +1031,45 @@ gst_tensor_trainer_set_prop_input_type (GstTensorTrainer * trainer,
 /**
  * @brief Find Trainer sub-plugin with the name.
  */
-static void
+static gboolean
 gst_tensor_trainer_find_framework (GstTensorTrainer * trainer, const char *name)
 {
   const GstTensorTrainerFramework *fw = NULL;
 
-  g_return_if_fail (name != NULL);
-  g_return_if_fail (trainer != NULL);
+  g_return_val_if_fail (name != NULL, FALSE);
+  g_return_val_if_fail (trainer != NULL, FALSE);
 
   GST_INFO_OBJECT (trainer, "Try to find framework: %s", name);
 
   fw = get_subplugin (NNS_SUBPLUGIN_TRAINER, name);
-  if (fw) {
-    GST_INFO_OBJECT (trainer, "Find framework %s:%p", trainer->fw_name, fw);
-    trainer->fw = fw;
-  } else {
+  if (!fw) {
     GST_ERROR_OBJECT (trainer, "Can not find framework(%s)", trainer->fw_name);
+    return FALSE;
   }
+
+  GST_INFO_OBJECT (trainer, "Find framework %s:%p", trainer->fw_name, fw);
+  trainer->fw = fw;
+
+  return TRUE;
 }
 
 /**
  * @brief Create NN framework.
  */
-static void
+static gboolean
 gst_tensor_trainer_create_framework (GstTensorTrainer * trainer)
 {
-  g_return_if_fail (trainer != NULL);
+  g_return_val_if_fail (trainer != NULL, FALSE);
 
   if (!trainer->fw || trainer->fw_created) {
     GST_ERROR_OBJECT (trainer, "fw is not opened(%d) or fw is not null(%p)",
         trainer->fw_created, trainer->fw);
-    return;
+    return FALSE;
   }
 
   if (!trainer->fw->create) {
     GST_ERROR_OBJECT (trainer, "Could not create framework");
-    return;
+    return FALSE;
   }
 
   GST_DEBUG_OBJECT (trainer, "%p", trainer->privateData);
@@ -1042,6 +1077,8 @@ gst_tensor_trainer_create_framework (GstTensorTrainer * trainer)
           &trainer->privateData) >= 0)
     trainer->fw_created = TRUE;
   GST_DEBUG_OBJECT (trainer, "Success, Framework: %p", trainer->privateData);
+
+  return TRUE;
 }
 
 /**
@@ -1070,17 +1107,24 @@ gst_tensor_trainer_get_tensor_size (GstTensorTrainer * trainer,
 /**
  * @brief Create model
  */
-static void
+static gboolean
 gst_tensor_trainer_create_model (GstTensorTrainer * trainer)
 {
-  g_return_if_fail (trainer != NULL);
+  gboolean ret = TRUE;
+
+  g_return_val_if_fail (trainer != NULL, FALSE);
+  g_return_val_if_fail (trainer->fw_name != NULL, FALSE);
+
+  ret = gst_tensor_trainer_find_framework (trainer, trainer->fw_name);
+  if (!ret)
+    return ret;
 
-  if (trainer->fw_name)
-    gst_tensor_trainer_find_framework (trainer, trainer->fw_name);
   if (trainer->fw) {
     /* model create and compile */
-    gst_tensor_trainer_create_framework (trainer);
+    ret = gst_tensor_trainer_create_framework (trainer);
   }
+
+  return ret;
 }
 
 /**
index bd0d2a6..ec60f5c 100644 (file)
@@ -33,11 +33,11 @@ typedef struct _GstTensorTrainerProperties
   GstTensorsInfo input_meta;    /**< configured input tensor info */
   const char *model_config;    /**< The configuration file path for creating model */
   const char *model_save_path;    /**< The file path to save the model */
-  int64_t num_inputs;    /**< The number of input lists, the input is where framework receive the features to train the model, num_inputs indicates how many inputs there are. */
-  int64_t num_labels;    /**< The number of label lists, the label is where framework receive the class to train the model, num_labels indicates how many labels there are. */
-  int64_t num_training_samples;    /**< The number of training sample used to train the model. */
-  int64_t num_validation_samples;    /**< The number of validation sample used to valid the model. */
-  int64_t num_epochs;    /**< The number of repetition of total training and validation sample. subplugin must receive total samples((num_training_samples + num_validation_samples) * num_epochs) */
+  unsigned int num_inputs;    /**< The number of input lists, the input is where framework receive the features to train the model, num_inputs indicates how many inputs there are. */
+  unsigned int num_labels;    /**< The number of label lists, the label is where framework receive the class to train the model, num_labels indicates how many labels there are. */
+  unsigned int num_training_samples;    /**< The number of training sample used to train the model. */
+  unsigned int num_validation_samples;    /**< The number of validation sample used to valid the model. */
+  unsigned int num_epochs;    /**< The number of repetition of total training and validation sample. subplugin must receive total samples((num_training_samples + num_validation_samples) * num_epochs) */
 
   GCond *training_complete_cond;    /**< Tensor trainer wait when receive EOS before model training is complete, subplugin should send signal when model training is complete. */
 } GstTensorTrainerProperties;
@@ -50,8 +50,8 @@ typedef struct _GstTensorTrainerProperties
 typedef struct _GstTensorTrainerFrameworkInfo
 {
   const char *name;    /**< Name of the neural network framework, searchable by FRAMEWORK property. */
-  int is_training_complete;  /**< Check if training is complete, Use int instead of gboolean because this is refered by custom plugins. */
-  int64_t epoch_cnt;    /**< Number of currently completed epochs */
+  unsigned int is_training_complete;  /**< Check if training is complete, Use unsigned int instead of gboolean because this is refered by custom plugins. */
+  unsigned int epoch_cnt;    /**< Number of currently completed epochs */
 } GstTensorTrainerFrameworkInfo;
 
 typedef struct _GstTensorTrainerFramework GstTensorTrainerFramework;