[tensor_trainer] Apply tensor trainer sub-plugin structure
authorhyunil park <hyunil46.park@samsung.com>
Thu, 1 Dec 2022 08:56:57 +0000 (17:56 +0900)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Thu, 2 Feb 2023 08:30:19 +0000 (17:30 +0900)
- Apply GstTensorTrainerProperties for internal data required by sub-plugin
- Apply GstTensorTrainerFramework for sub-plugin definition
- Apply trainer subpluginType(NNS_SUBPLUGIN_TRAINER) to get, register and
  unregister subplugin
- Add model-save-path property to set path for saving trained model

Signed-off-by: hyunil park <hyunil46.park@samsung.com>
gst/nnstreamer/elements/gsttensor_trainer.c
gst/nnstreamer/elements/gsttensor_trainer.h
jni/nnstreamer.mk

index 1a469d9..911aa2f 100644 (file)
@@ -74,6 +74,7 @@ enum
   PROP_0,
   PROP_FRAMEWORK,
   PROP_MODEL_CONFIG,
+  PROP_MODEL_SAVE_PATH,
   PROP_INPUT_DIM,
   PROP_OUTPUT_DIM,
   PROP_INPUT_TYPE,
@@ -91,7 +92,6 @@ static void gst_tensor_trainer_set_property (GObject * object, guint prop_id,
     const GValue * value, GParamSpec * pspec);
 static void gst_tensor_trainer_get_property (GObject * object, guint prop_id,
     GValue * value, GParamSpec * pspec);
-static gboolean gst_tensor_trainer_start (GstBaseTransform * trans);
 static gboolean gst_tensor_trainer_stop (GstBaseTransform * trans);
 static void gst_tensor_trainer_finalize (GObject * object);
 static gboolean gst_tensor_trainer_sink_event (GstBaseTransform * trans,
@@ -112,24 +112,23 @@ static gboolean gst_tensor_trainer_transform_size (GstBaseTransform * trans,
     GstPadDirection direction, GstCaps * caps, gsize size, GstCaps * othercaps,
     gsize * othersize);
 
-static void
-gst_tensor_trainer_set_prop_framework (GstTensorTrainer * trainer,
+static void gst_tensor_trainer_set_prop_framework (GstTensorTrainer * trainer,
     const GValue * value);
 static void gst_tensor_trainer_set_prop_model_config_file_path (GstTensorTrainer
     * trainer, const GValue * value);
+static void gst_tensor_trainer_set_model_save_path (GstTensorTrainer * trainer,
+    const GValue * value);
 static void gst_tensor_trainer_set_prop_dimension (GstTensorTrainer * trainer,
     const GValue * value, const gboolean is_input);
 static void gst_tensor_trainer_set_prop_type (GstTensorTrainer * trainer,
     const GValue * value, const gboolean is_input);
-static const GstTensorFilterFramework
-    * gst_tensor_trainer_find_best_framework (const char *names);
 static void gst_tensor_trainer_find_framework (GstTensorTrainer * trainer,
     const char *name);
 static void gst_tensor_trainer_create_framework (GstTensorTrainer * trainer);
 static gsize gst_tensor_trainer_get_tensor_size (GstTensorTrainer * trainer,
     guint index, gboolean is_input);
-static void gst_tensor_trainer_calc_input_tensors_size (GstTensorTrainer *
-    trainer);
+static void gst_tensor_trainer_create_model (GstTensorTrainer * trainer);
+static void gst_tensor_trainer_train_model (GstTensorTrainer * trainer);
 
 /**
  * @brief initialize the tensor_trainer's class
@@ -158,8 +157,6 @@ gst_tensor_trainer_class_init (GstTensorTrainerClass * klass)
   gstelement_class->change_state =
       GST_DEBUG_FUNCPTR (gst_tensor_trainer_change_state);
 
-  /* Called when the element starts processing */
-  trans_class->start = GST_DEBUG_FUNCPTR (gst_tensor_trainer_start);
   /* Called when the element stop processing */
   trans_class->stop = GST_DEBUG_FUNCPTR (gst_tensor_trainer_stop);
 
@@ -196,6 +193,14 @@ gst_tensor_trainer_class_init (GstTensorTrainerClass * klass)
           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
           G_PARAM_STATIC_STRINGS));
 
+  g_object_class_install_property (gobject_class, PROP_MODEL_SAVE_PATH,
+      g_param_spec_string ("model-save-path", "Model save path",
+          "Path to save the trained model in framework, if model-config "
+          "contains information about the save file, it is ignored",
+          DEFAULT_STR_PROP_VALUE,
+          G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
+          G_PARAM_STATIC_STRINGS));
+
   g_object_class_install_property (gobject_class, PROP_INPUT_DIM,
       g_param_spec_string ("input-dim", "Input dimension",
           "Input tensors dimension from inner array, up to 4 dimensions ?", "",
@@ -277,6 +282,7 @@ gst_tensor_trainer_init (GstTensorTrainer * trainer)
   GST_DEBUG ("<ENTER>");
   trainer->fw_name = g_strdup (DEFAULT_PROP_FRAMEWORK);
   trainer->model_config = g_strdup (DEFAULT_STR_PROP_VALUE);
+  trainer->model_save_path = g_strdup (DEFAULT_STR_PROP_VALUE);
   trainer->input_dimensions = g_strdup (DEFAULT_STR_PROP_VALUE);
   trainer->output_dimensions = g_strdup (DEFAULT_STR_PROP_VALUE);
   trainer->input_type = g_strdup (DEFAULT_STR_PROP_VALUE);
@@ -284,7 +290,7 @@ gst_tensor_trainer_init (GstTensorTrainer * trainer)
   trainer->push_output = FALSE;
 
   trainer->fw = NULL;
-  trainer->fw_opened = 0;       /* for test */
+  trainer->fw_created = 0;      /* for test */
   trainer->configured = 0;
   trainer->input_configured = 0;
   trainer->output_configured = 0;
@@ -304,14 +310,14 @@ gst_tensor_trainer_finalize (GObject * object)
 
   g_free (trainer->fw_name);
   g_free (trainer->model_config);
+  g_free (trainer->model_save_path);
   g_free (trainer->input_dimensions);
   g_free (trainer->output_dimensions);
   g_free (trainer->input_type);
   g_free (trainer->output_type);
 
-  GST_DEBUG ("trainer->fw_created=%d", trainer->fw_created);
-  if (trainer->fw_created) {
-    trainer->fw->close (&trainer->prop, &trainer->privateData);
+  if (trainer->fw_created && trainer->fw) {
+    trainer->fw->destroy (trainer->fw, &trainer->prop, &trainer->privateData);
   }
   /* need to free prop data */
 
@@ -337,6 +343,9 @@ gst_tensor_trainer_set_property (GObject * object, guint prop_id,
     case PROP_MODEL_CONFIG:
       gst_tensor_trainer_set_prop_model_config_file_path (trainer, value);
       break;
+    case PROP_MODEL_SAVE_PATH:
+      gst_tensor_trainer_set_model_save_path (trainer, value);
+      break;
     case PROP_INPUT_DIM:
       gst_tensor_trainer_set_prop_dimension (trainer, value, TRUE);
       break;
@@ -354,16 +363,16 @@ gst_tensor_trainer_set_property (GObject * object, guint prop_id,
       GST_INFO_OBJECT (trainer, "push output: %d", trainer->push_output);
       break;
     case PROP_NUM_INPUTS:
-      trainer->num_inputs = g_value_get_uint (value);
+      trainer->prop.num_inputs = g_value_get_uint (value);
       break;
     case PROP_NUM_LABELS:
-      trainer->num_labels = g_value_get_uint (value);
+      trainer->prop.num_labels = g_value_get_uint (value);
       break;
     case PROP_NUM_TRAINING_SAMPLES:
-      trainer->num_training_samples = g_value_get_uint (value);
+      trainer->prop.num_train_samples = g_value_get_uint (value);
       break;
     case PROP_NUM_VALIDATION_SAMPLES:
-      trainer->num_validation_samples = g_value_get_uint (value);
+      trainer->prop.num_valid_samples = g_value_get_uint (value);
       break;
     default:
       G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
@@ -389,6 +398,9 @@ gst_tensor_trainer_get_property (GObject * object, guint prop_id,
     case PROP_MODEL_CONFIG:
       g_value_set_string (value, trainer->model_config);
       break;
+    case PROP_MODEL_SAVE_PATH:
+      g_value_set_string (value, trainer->model_save_path);
+      break;
     case PROP_INPUT_DIM:
       g_value_set_string (value, trainer->input_dimensions);
       break;
@@ -405,16 +417,16 @@ gst_tensor_trainer_get_property (GObject * object, guint prop_id,
       g_value_set_boolean (value, trainer->push_output);
       break;
     case PROP_NUM_INPUTS:
-      g_value_set_uint (value, trainer->num_inputs);
+      g_value_set_uint (value, trainer->prop.num_inputs);
       break;
     case PROP_NUM_LABELS:
-      g_value_set_uint (value, trainer->num_labels);
+      g_value_set_uint (value, trainer->prop.num_labels);
       break;
     case PROP_NUM_TRAINING_SAMPLES:
-      g_value_set_uint (value, trainer->num_training_samples);
+      g_value_set_uint (value, trainer->prop.num_train_samples);
       break;
     case PROP_NUM_VALIDATION_SAMPLES:
-      g_value_set_uint (value, trainer->num_validation_samples);
+      g_value_set_uint (value, trainer->prop.num_valid_samples);
       break;
     default:
       G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
@@ -439,11 +451,12 @@ gst_tensor_trainer_change_state (GstElement * element,
 
     case GST_STATE_CHANGE_READY_TO_PAUSED:
       GST_INFO_OBJECT (trainer, "READY_TO_PAUSED");
+      gst_tensor_trainer_create_model (trainer);
       break;
 
     case GST_STATE_CHANGE_PAUSED_TO_PLAYING:
       GST_INFO_OBJECT (trainer, "PAUSED_TO_PLAYING");
-      /* start or resume model train */
+      gst_tensor_trainer_train_model (trainer);
       break;
 
     default:
@@ -535,28 +548,6 @@ gst_tensor_trainer_src_event (GstBaseTransform * trans, GstEvent * event)
   return GST_BASE_TRANSFORM_CLASS (parent_class)->src_event (trans, event);
 }
 
-
-/**
- * @brief Called when the element starts processing. optional vmethod of BaseTransform
- */
-static gboolean
-gst_tensor_trainer_start (GstBaseTransform * trans)
-{
-  GstTensorTrainer *trainer;
-  trainer = GST_TENSOR_TRAINER_CAST (trans);
-
-  if (trainer->fw_name)
-    gst_tensor_trainer_find_framework (trainer, trainer->fw_name);
-  if (trainer->fw) {
-    /* calc input tensors size */
-    gst_tensor_trainer_calc_input_tensors_size (trainer);
-    /* create, compile */
-    gst_tensor_trainer_create_framework (trainer);
-  }
-
-  return TRUE;
-}
-
 /**
  * @brief Called when the element stops processing. optional vmethod of BaseTransform
  */
@@ -625,15 +616,17 @@ gst_tensor_trainer_transform (GstBaseTransform * trans, GstBuffer * inbuf,
 
   /* Prepare tensor to invoke */
   /* Check number of input tensors */
-  if (mem_blocks != trainer->input_meta.num_tensors) {
+  GST_ERROR ("num_tensors: %d", trainer->prop.input_meta.num_tensors);
+  if (mem_blocks != trainer->prop.input_meta.num_tensors) {
     GST_ERROR_OBJECT (trainer, "Invalid memory blocks(%d),"
         "number of input tensors may be (%d)", mem_blocks,
-        trainer->input_meta.num_tensors);
+        trainer->prop.input_meta.num_tensors);
     goto error;
   }
 
   /* Check size of input tensors */
-  for (i = 0; i < trainer->input_meta.num_tensors; i++) {
+  GST_ERROR ("num_tensors: %d", trainer->prop.input_meta.num_tensors);
+  for (i = 0; i < trainer->prop.input_meta.num_tensors; i++) {
     expected = gst_tensor_trainer_get_tensor_size (trainer, i, TRUE);
     if (expected != in_tensors[i].size) {
       GST_ERROR_OBJECT (trainer, "Invalid tensor size (%u'th memory chunk: %zd)"
@@ -693,7 +686,7 @@ gst_tensor_trainer_transform (GstBaseTransform * trans, GstBuffer * inbuf,
     }
     /* Call the trainer-subplugin callback, invoke */
     ret =
-        trainer->fw->invoke_NN (&trainer->prop, &trainer->privateData,
+        trainer->fw->invoke (trainer->fw, &trainer->prop, trainer->privateData,
         invoke_tensors, out_tensors);
 
     /* Free out info */
@@ -706,7 +699,7 @@ gst_tensor_trainer_transform (GstBaseTransform * trans, GstBuffer * inbuf,
     }
   } else {
     ret =
-        trainer->fw->invoke_NN (&trainer->prop, &trainer->privateData,
+        trainer->fw->invoke (trainer->fw, &trainer->prop, trainer->privateData,
         invoke_tensors, NULL);
   }
 
@@ -780,7 +773,9 @@ gst_tensor_trainer_transform_caps (GstBaseTransform * trans,
   gst_tensors_config_init (&in_config);
   gst_tensors_config_init (&out_config);
   structure = gst_caps_get_structure (caps, 0);
-  gst_tensors_config_from_structure (&in_config, structure);
+
+  if (!gst_tensors_config_from_structure (&in_config, structure))
+    return NULL;
 
   /* Set framerate from input config */
   out_config.rate_n = in_config.rate_n;
@@ -790,7 +785,7 @@ gst_tensor_trainer_transform_caps (GstBaseTransform * trans,
      for trainer->input_meta and trainer->output_meta */
   if (direction == GST_PAD_SRC) {
     if (trainer->input_configured) {
-      gst_tensors_info_copy (&out_config.info, &trainer->input_meta);
+      gst_tensors_info_copy (&out_config.info, &trainer->prop.input_meta);
       configured = TRUE;
     }
   } else {
@@ -857,9 +852,11 @@ gst_tensor_trainer_set_caps (GstBaseTransform * trans, GstCaps * incaps,
 
   gst_tensors_config_init (&in_config);
   structure = gst_caps_get_structure (incaps, 0);
-  gst_tensors_config_from_structure (&in_config, structure);
 
-  if (!gst_tensors_info_is_equal (&in_config.info, &trainer->input_meta)) {
+  if (!gst_tensors_config_from_structure (&in_config, structure))
+    return FALSE;
+
+  if (!gst_tensors_info_is_equal (&in_config.info, &trainer->prop.input_meta)) {
     GST_ERROR_OBJECT (trainer,
         "The input tensors info is different between incaps and set property value. "
         "Please check pipeline's input tensor info and tensor_trainer's set property values"
@@ -895,8 +892,23 @@ gst_tensor_trainer_set_prop_model_config_file_path (GstTensorTrainer * trainer,
 {
   g_free (trainer->model_config);
   trainer->model_config = g_value_dup_string (value);
+  trainer->prop.model_config = trainer->model_config;
   GST_INFO_OBJECT (trainer, "model configuration file path: %s",
-      trainer->model_config);
+      trainer->prop.model_config);
+}
+
+/**
+ * @brief Handle "PROP_MODEL_SAVE_PATH" for set-property
+ */
+static void
+gst_tensor_trainer_set_model_save_path (GstTensorTrainer * trainer,
+    const GValue * value)
+{
+  g_free (trainer->model_save_path);
+  trainer->model_save_path = g_value_dup_string (value);
+  trainer->prop.model_save_path = trainer->model_save_path;
+  GST_INFO_OBJECT (trainer, "file path to save the model: %s",
+      trainer->prop.model_save_path);
 }
 
 /**
@@ -921,7 +933,7 @@ gst_tensor_trainer_set_prop_dimension (GstTensorTrainer * trainer,
   }
 
   if (is_input) {
-    info = &trainer->input_meta;
+    info = &trainer->prop.input_meta;
     rank = trainer->input_ranks;
     trainer->input_configured = TRUE;
     g_free (trainer->input_dimensions);
@@ -962,7 +974,6 @@ gst_tensor_trainer_set_prop_type (GstTensorTrainer * trainer,
 {
   GstTensorsInfo *info;
   guint num_types;
-  guint i;
 
   if ((is_input && trainer->inputtype_configured) || (!is_input
           && trainer->outputtype_configured)) {
@@ -973,7 +984,7 @@ gst_tensor_trainer_set_prop_type (GstTensorTrainer * trainer,
   }
 
   if (is_input) {
-    info = &trainer->input_meta;
+    info = &trainer->prop.input_meta;
     trainer->inputtype_configured = TRUE;
     g_free (trainer->input_type);
     trainer->input_type = g_value_dup_string (value);
@@ -988,9 +999,6 @@ gst_tensor_trainer_set_prop_type (GstTensorTrainer * trainer,
 
   num_types =
       gst_tensors_info_parse_types_string (info, g_value_get_string (value));
-  for (i = 0; i < num_types; i++) {
-    trainer->tensors_inputtype[i] = info->info[i].type;
-  }
 
   info->num_tensors = num_types;
 }
@@ -1001,30 +1009,14 @@ gst_tensor_trainer_set_prop_type (GstTensorTrainer * trainer,
 static void
 gst_tensor_trainer_find_framework (GstTensorTrainer * trainer, const char *name)
 {
-  const GstTensorFilterFramework *fw = NULL;
-  gchar *str;
+  const GstTensorTrainerFramework *fw = NULL;
+
   g_return_if_fail (name != NULL);
   g_return_if_fail (trainer != NULL);
 
   GST_INFO ("find framework: %s", name);
 
-  /* Need to add trainer type to subpluginType */
-  fw = get_subplugin (NNS_SUBPLUGIN_FILTER, name);
-
-  if (fw == NULL) {
-    /*Get sub-plugin priority from ini file and find sub-plugin */
-    str = nnsconf_get_custom_value_string (name, "subplugin_priority");
-    fw = gst_tensor_trainer_find_best_framework (str);
-    g_free (str);
-  }
-
-  if (fw == NULL) {
-    /* Check the filter-alias from ini file */
-    str = nnsconf_get_custom_value_string ("filter-aliases", name);
-    fw = gst_tensor_trainer_find_best_framework (str);
-    g_free (str);
-  }
-
+  fw = get_subplugin (NNS_SUBPLUGIN_TRAINER, name);
   if (fw) {
     GST_INFO_OBJECT (trainer, "find framework %s:%p", trainer->fw_name, fw);
     trainer->fw = fw;
@@ -1039,57 +1031,28 @@ gst_tensor_trainer_find_framework (GstTensorTrainer * trainer, const char *name)
 static void
 gst_tensor_trainer_create_framework (GstTensorTrainer * trainer)
 {
-  if (!trainer->fw || trainer->fw_opened) {
+  g_return_if_fail (trainer != NULL);
+
+  if (!trainer->fw || trainer->fw_created) {
     GST_ERROR_OBJECT (trainer, "fw is not opened(%d) or fw is null(%p)",
-        trainer->fw_opened, trainer->fw);
+        trainer->fw_created, trainer->fw);
     return;
   }
 
   /* For test */
-  if (!trainer->fw->open) {     /* fw->create, create model with configuration file */
+  if (!trainer->fw->create) {   /* fw->create, create model with configuration file */
     GST_ERROR_OBJECT (trainer, "Could not find fw->create");
     return;
   }
   /* Test code, need to create with load ini file */
   GST_ERROR ("%p", trainer->privateData);
-  if (trainer->fw->open (&trainer->prop, &trainer->privateData) >= 0)
+  if (trainer->fw->create (trainer->fw, &trainer->prop,
+          &trainer->privateData) >= 0)
     trainer->fw_created = TRUE;
   GST_ERROR ("%p", trainer->privateData);
 }
 
 /**
- * @brief Find sub-plugin trainer given the name list
- */
-static const GstTensorFilterFramework *
-gst_tensor_trainer_find_best_framework (const char *names)
-{
-  const GstTensorFilterFramework *fw = NULL;    /* need to change to GstTensorTrainerFramework */
-  gchar **subplugins;
-  guint i, len;
-
-  if (names == NULL || names[0] == '\0')
-    return NULL;
-
-  subplugins = g_strsplit_set (names, " ,;", -1);
-
-  len = g_strv_length (subplugins);
-
-  for (i = 0; i < len; i++) {
-    if (strlen (g_strstrip (subplugins[i])) == 0)
-      continue;
-
-    fw = get_subplugin (NNS_SUBPLUGIN_FILTER, subplugins[i]);   /* need to add trainer type to subpluginType */
-    if (fw) {
-      GST_INFO ("i = %d found %s", i, subplugins[i]);
-      break;
-    }
-  }
-  g_strfreev (subplugins);
-
-  return fw;
-}
-
-/**
  * @brief Calculate tensor buffer size
  */
 gsize
@@ -1099,7 +1062,7 @@ gst_tensor_trainer_get_tensor_size (GstTensorTrainer * trainer, guint index,
   GstTensorsInfo *info;
 
   if (is_input)
-    info = &trainer->input_meta;
+    info = &trainer->prop.input_meta;
   else
     info = &trainer->output_meta;
 
@@ -1139,40 +1102,90 @@ gst_tensor_trainer_transform_size (GstBaseTransform * trans,
 }
 
 /**
- * @brief Calculate the size of input tensors
+ * @brief Create model
  */
 static void
-gst_tensor_trainer_calc_input_tensors_size (GstTensorTrainer * trainer)
+gst_tensor_trainer_create_model (GstTensorTrainer * trainer)
 {
-  GstTensorsInfo *info;
-  guint i, j, idx, size[NNS_TENSOR_SIZE_LIMIT] = { 1, };
-
-  const guint get_tensor_size[] = {
-    [_NNS_INT32] = 4,
-    [_NNS_UINT32] = 4,
-    [_NNS_INT16] = 2,
-    [_NNS_UINT16] = 2,
-    [_NNS_INT8] = 1,
-    [_NNS_UINT8] = 1,
-    [_NNS_FLOAT64] = 8,
-    [_NNS_FLOAT32] = 4,
-    [_NNS_INT64] = 8,
-    [_NNS_UINT64] = 8,
-    [_NNS_FLOAT16] = 2,
-  };
+  g_return_if_fail (trainer != NULL);
+  GST_DEBUG_OBJECT (trainer, "called");
+
+  if (trainer->fw_name)
+    gst_tensor_trainer_find_framework (trainer, trainer->fw_name);
+  if (trainer->fw) {
+    /* model create and compile */
+    gst_tensor_trainer_create_framework (trainer);
+  }
+}
 
+/**
+ * @brief Train model
+ */
+static void
+gst_tensor_trainer_train_model (GstTensorTrainer * trainer)
+{
+  gint ret = -1;
   g_return_if_fail (trainer != NULL);
+  g_return_if_fail (trainer->fw != NULL);
+  g_return_if_fail (trainer->fw->train != NULL);
 
-  info = &trainer->input_meta;
-  idx = 0;
-  for (i = 0; i < info->num_tensors; i++) {
-    for (j = 0; j < NNS_TENSOR_RANK_LIMIT; j++) {
-      size[idx] *= info->info[i].dimension[j];
-    }
-    trainer->tensors_inputsize[idx] =
-        size[idx] * get_tensor_size[trainer->tensors_inputtype[idx]];
-    GST_DEBUG_OBJECT (trainer, "trainer->tensors_inputsize[%d]= %d", idx,
-        trainer->tensors_inputsize[idx]);
-    idx++;
+  GST_DEBUG_OBJECT (trainer, "called");
+  ret = trainer->fw->train (trainer->fw, &trainer->prop, trainer->privateData);
+  if (ret != 0) {
+    GST_ERROR_OBJECT (trainer, "model train is failed");
+  }
+}
+
+/**
+ * @brief Trainer's sub-plugin should call this function to register itself.
+ * @param[in] ttsp tensor_trainer sub-plugin to be registered.
+ * @return TRUE if registered. FALSE is failed or duplicated.
+ */
+int
+nnstreamer_trainer_probe (GstTensorTrainerFramework * ttsp)
+{
+  GstTensorTrainerFrameworkInfo info;
+  GstTensorTrainerProperties prop;
+  const char *name = NULL;
+  int ret = 0;
+
+  g_return_val_if_fail (ttsp != NULL, 0);
+
+  memset (&prop, 0, sizeof (GstTensorTrainerProperties));
+  gst_tensors_info_init (&prop.input_meta);
+
+  if (ret != ttsp->getFrameworkInfo (ttsp, &prop, NULL, &info)) {
+    GST_ERROR ("getFrameworkInfo() failed");
+    return FALSE;
   }
+  name = info.name;
+
+  return register_subplugin (NNS_SUBPLUGIN_TRAINER, name, ttsp);
+}
+
+/**
+ * @brief Trainer's sub-plugin may call this to unregister itself.
+ * @param[in] ttsp tensor_trainer sub-plugin to be unregistered.
+ * @return TRUE if unregistered. FALSE is failed.
+ */
+int
+nnstreamer_trainer_exit (GstTensorTrainerFramework * ttsp)
+{
+  GstTensorTrainerFrameworkInfo info;
+  GstTensorTrainerProperties prop;
+  const char *name = NULL;
+  int ret = 0;
+
+  g_return_val_if_fail (ttsp != NULL, 0);
+
+  memset (&prop, 0, sizeof (GstTensorTrainerProperties));
+  gst_tensors_info_init (&prop.input_meta);
+
+  if (ret != ttsp->getFrameworkInfo (ttsp, &prop, NULL, &info)) {
+    GST_ERROR ("getFrameworkInfo() failed");
+    return FALSE;
+  }
+  name = info.name;
+
+  return unregister_subplugin (NNS_SUBPLUGIN_TRAINER, name);
 }
index ea90c84..594f56f 100644 (file)
@@ -20,7 +20,7 @@
 #include <tensor_common.h>
 
 #include <nnstreamer_plugin_api_util.h>
-#include <nnstreamer_plugin_api_filter.h>
+#include <nnstreamer_plugin_api_trainer.h>
 
 G_BEGIN_DECLS
 #define GST_TYPE_TENSOR_TRAINER \
@@ -37,7 +37,6 @@ G_BEGIN_DECLS
 typedef struct _GstTensorTrainer GstTensorTrainer;
 typedef struct _GstTensorTrainerClass GstTensorTrainerClass;
 
-
 /**
  * @brief GstTensorTrainer data structure
  */
@@ -47,42 +46,29 @@ struct _GstTensorTrainer
 
   gchar *fw_name;
   gchar *model_config;
+  gchar *model_save_path;
   gchar *input_dimensions;
   gchar *output_dimensions;
   gchar *input_type;
   gchar *output_type;
   gboolean push_output;
-  unsigned int num_inputs;
-  unsigned int num_labels;
-  unsigned int num_training_samples;
-  unsigned int num_validation_samples;
-
-  GstTensorsInfo input_meta;
-  GstTensorsInfo output_meta;
 
   gboolean configured;
-
   int input_configured;
   int output_configured;
   int inputtype_configured;
   int outputtype_configured;
   unsigned int input_ranks[NNS_TENSOR_SIZE_LIMIT];
   unsigned int output_ranks[NNS_TENSOR_SIZE_LIMIT];
-
-  tensor_type tensors_inputtype[NNS_TENSOR_SIZE_LIMIT];
-  unsigned int tensors_inputsize[NNS_TENSOR_SIZE_LIMIT];
+  GstTensorsInfo output_meta;
 
   /* draft */
-  int fw_opened;
-  int fw_compiled;
-  int fw_fitted;
   int fw_created;
   int fw_stop;
-  int fw_paused;
 
   void *privateData; /**< NNFW plugin's private data is stored here */
-  const GstTensorFilterFramework *fw;   /* for test, need to make */
-  GstTensorFilterProperties prop; /**< NNFW plugin's properties */
+  const GstTensorTrainerFramework *fw;  /* for test, need to make */
+  GstTensorTrainerProperties prop; /**< NNFW plugin's properties */
 };
 
 /**
index 0f5e9f9..adcb060 100644 (file)
@@ -67,6 +67,7 @@ NNSTREAMER_PLUGINS_SRCS := \
     $(NNSTREAMER_GST_HOME)/elements/gsttensor_sparseenc.c \
     $(NNSTREAMER_GST_HOME)/elements/gsttensor_sparseutil.c \
     $(NNSTREAMER_GST_HOME)/elements/gsttensor_split.c \
+    $(NNSTREAMER_GST_HOME)/elements/gsttensor_trainer.c \
     $(NNSTREAMER_GST_HOME)/elements/gsttensor_transform.c \
     $(NNSTREAMER_GST_HOME)/tensor_filter/tensor_filter.c