[tensor_trainer] Add properties for training in sub-plugin
authorhyunil park <hyunil46.park@samsung.com>
Mon, 31 Oct 2022 04:59:29 +0000 (13:59 +0900)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Thu, 2 Feb 2023 08:30:19 +0000 (17:30 +0900)
- Add model-config to set model configuration file,
  it is used to configure the model to be trained in framework
- Add push-output to push output tensors, default value is false
- Add num-inputs to set how many inputs are received,
  an input in a tensor can have one or more features data
- Add num-labels to set how many labels are received,
  a label in a tensor can have one or more classes data
- Add num-training-samples to set how many samples are taken for training model
- Add num-validation-samples to set how many samples are taken for validation model

Signed-off-by: hyunil park <hyunil46.park@samsung.com>
gst/nnstreamer/elements/gsttensor_trainer.c
gst/nnstreamer/elements/gsttensor_trainer.h
gst/nnstreamer/elements/meson.build

index 69223a2..1a469d9 100644 (file)
  *
  * ## Example launch line
  * |[
- * gst-launch-1.0 videotestsrc !
- *    video/x-raw, format=RGB, width=640, height=480 ! tensor_converter ! 
- *    tensor_trainer input=3:640:480 inputtype=uint8 output=1:1:1:1 outputtype=uint8 !
- *    tensor_sink
+ * gst-launch-1.0 videotestsrc ! video/x-raw, format=RGB, width=640, height=480 !
+ * tensor_converter ! tensor_transform mode=typecast option=float32 !
+ * tensor_trainer framework=nntrainer model-config=/usr/bin/model.ini
+ * push-output=false input=3:640:480:1 inputtype=float32
+ * output=1:1:1:1 outputtype=float32 ! tensor_sink
  * ]|
  *
  */
@@ -26,6 +27,8 @@
 #include <nnstreamer_subplugin.h>
 #include <nnstreamer_util.h>
 #include "gsttensor_trainer.h"
+#include <unistd.h>
+#include <sys/syscall.h>
 
 /**
  * @brief Default caps string for both sink and source pad.
@@ -56,7 +59,7 @@ G_DEFINE_TYPE (GstTensorTrainer, gst_tensor_trainer, GST_TYPE_BASE_TRANSFORM);
 /**
  * @brief Default framework property value
  */
-#define DEFAULT_FRAMEWORK "nntrainer"
+#define DEFAULT_PROP_FRAMEWORK "nntrainer"
 
 /**
  * @brief Default string property value 
@@ -70,12 +73,20 @@ enum
 {
   PROP_0,
   PROP_FRAMEWORK,
+  PROP_MODEL_CONFIG,
   PROP_INPUT_DIM,
   PROP_OUTPUT_DIM,
   PROP_INPUT_TYPE,
-  PROP_OUTPUTTYPE
+  PROP_OUTPUT_TYPE,
+  PROP_PUSH_OUTPUT,
+  PROP_NUM_INPUTS,              /* number of input list */
+  PROP_NUM_LABELS,              /* number of label list */
+  PROP_NUM_TRAINING_SAMPLES,    /*number of training data */
+  PROP_NUM_VALIDATION_SAMPLES,  /*number of validation data */
 };
 
+#define NNS_TENSOR_TYPES 11
+
 static void gst_tensor_trainer_set_property (GObject * object, guint prop_id,
     const GValue * value, GParamSpec * pspec);
 static void gst_tensor_trainer_get_property (GObject * object, guint prop_id,
@@ -102,18 +113,23 @@ static gboolean gst_tensor_trainer_transform_size (GstBaseTransform * trans,
     gsize * othersize);
 
 static void
-gst_tensor_trainer_set_prop_framework (GstTensorTrainer * trainer, const gchar * fw_name);
+gst_tensor_trainer_set_prop_framework (GstTensorTrainer * trainer,
+    const GValue * value);
+static void gst_tensor_trainer_set_prop_model_config_file_path (GstTensorTrainer
+    * trainer, const GValue * value);
 static void gst_tensor_trainer_set_prop_dimension (GstTensorTrainer * trainer,
     const GValue * value, const gboolean is_input);
 static void gst_tensor_trainer_set_prop_type (GstTensorTrainer * trainer,
     const GValue * value, const gboolean is_input);
 static const GstTensorFilterFramework
-    *gst_tensor_trainer_find_best_framework (const char *names);
+    * gst_tensor_trainer_find_best_framework (const char *names);
 static void gst_tensor_trainer_find_framework (GstTensorTrainer * trainer,
     const char *name);
 static void gst_tensor_trainer_create_framework (GstTensorTrainer * trainer);
 static gsize gst_tensor_trainer_get_tensor_size (GstTensorTrainer * trainer,
     guint index, gboolean is_input);
+static void gst_tensor_trainer_calc_input_tensors_size (GstTensorTrainer *
+    trainer);
 
 /**
  * @brief initialize the tensor_trainer's class
@@ -166,8 +182,17 @@ gst_tensor_trainer_class_init (GstTensorTrainerClass * klass)
 
   /* Install properties for tensor_trainer */
   g_object_class_install_property (gobject_class, PROP_FRAMEWORK,
-      g_param_spec_string ("framework", "Framework", "Neural network framework",
-          DEFAULT_FRAMEWORK,
+      g_param_spec_string ("framework", "Framework",
+          "Neural network framework to be used for model training",
+          DEFAULT_PROP_FRAMEWORK,
+          G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
+          G_PARAM_STATIC_STRINGS));
+
+  g_object_class_install_property (gobject_class, PROP_MODEL_CONFIG,
+      g_param_spec_string ("model-config", "Model configuration file path",
+          "Model configuration file is used to configure the model "
+          "to be trained in neural network framework, set the file path",
+          DEFAULT_STR_PROP_VALUE,
           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
           G_PARAM_STATIC_STRINGS));
 
@@ -195,6 +220,43 @@ gst_tensor_trainer_class_init (GstTensorTrainerClass * klass)
           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
           G_PARAM_STATIC_STRINGS));
 
+  g_object_class_install_property (gobject_class, PROP_PUSH_OUTPUT,
+      g_param_spec_boolean ("push-output", "Push output tensor",
+          "Add output tensors to output GstBuffer", FALSE,
+          G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
+          G_PARAM_STATIC_STRINGS));
+
+  g_object_class_install_property (gobject_class, PROP_NUM_INPUTS,
+      g_param_spec_uint ("num-inputs", "Number of inputs",
+          "An input in a tensor can have one or more features data,"
+          "set how many inputs are received", 0, NNS_TENSOR_SIZE_LIMIT, 1,
+          G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
+          G_PARAM_STATIC_STRINGS));
+
+  g_object_class_install_property (gobject_class, PROP_NUM_LABELS,
+      g_param_spec_uint ("num-labels", "Number of labels",
+          "A label in a tensor can have one or more classes data,"
+          "set how many labes are recevied", 0, NNS_TENSOR_SIZE_LIMIT, 1,
+          G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
+          G_PARAM_STATIC_STRINGS));
+
+  g_object_class_install_property (gobject_class, PROP_NUM_TRAINING_SAMPLES,
+      g_param_spec_uint ("num-training-samples", "Number of training samples",
+          "A sample can consist of muliple inputs and labels in tensors of a gstbuffer"
+          ", set how many samples are taken for training model", 0, G_MAXUINT,
+          1,
+          G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
+          G_PARAM_STATIC_STRINGS));
+
+  g_object_class_install_property (gobject_class, PROP_NUM_VALIDATION_SAMPLES,
+      g_param_spec_uint ("num-validation-samples",
+          "Number of validation samples",
+          "A sample can consist of muliple inputs and labels in tensors of a gstbuffer"
+          ", set how many samples are taken for validation model",
+          0, G_MAXUINT, 1,
+          G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
+          G_PARAM_STATIC_STRINGS));
+
   gst_element_class_set_details_simple (gstelement_class, "TensorTrainer",
       "Trainer/Tensor", "Train tensor data using NN Frameworks",
       "Samsung Electronics Co., Ltd.");
@@ -213,14 +275,16 @@ static void
 gst_tensor_trainer_init (GstTensorTrainer * trainer)
 {
   GST_DEBUG ("<ENTER>");
-  trainer->fw_name = g_strdup (DEFAULT_FRAMEWORK);
+  trainer->fw_name = g_strdup (DEFAULT_PROP_FRAMEWORK);
+  trainer->model_config = g_strdup (DEFAULT_STR_PROP_VALUE);
   trainer->input_dimensions = g_strdup (DEFAULT_STR_PROP_VALUE);
   trainer->output_dimensions = g_strdup (DEFAULT_STR_PROP_VALUE);
   trainer->input_type = g_strdup (DEFAULT_STR_PROP_VALUE);
   trainer->output_type = g_strdup (DEFAULT_STR_PROP_VALUE);
+  trainer->push_output = FALSE;
 
   trainer->fw = NULL;
-  trainer->fw_opened = 0; /* for test */
+  trainer->fw_opened = 0;       /* for test */
   trainer->configured = 0;
   trainer->input_configured = 0;
   trainer->output_configured = 0;
@@ -239,11 +303,18 @@ gst_tensor_trainer_finalize (GObject * object)
   trainer = GST_TENSOR_TRAINER (object);
 
   g_free (trainer->fw_name);
+  g_free (trainer->model_config);
   g_free (trainer->input_dimensions);
   g_free (trainer->output_dimensions);
   g_free (trainer->input_type);
   g_free (trainer->output_type);
 
+  GST_DEBUG ("trainer->fw_created=%d", trainer->fw_created);
+  if (trainer->fw_created) {
+    trainer->fw->close (&trainer->prop, &trainer->privateData);
+  }
+  /* need to free prop data */
+
   G_OBJECT_CLASS (parent_class)->finalize (object);
 }
 
@@ -261,7 +332,10 @@ gst_tensor_trainer_set_property (GObject * object, guint prop_id,
   switch (prop_id) {
 
     case PROP_FRAMEWORK:
-      gst_tensor_trainer_set_prop_framework (trainer, g_value_get_string (value));
+      gst_tensor_trainer_set_prop_framework (trainer, value);
+      break;
+    case PROP_MODEL_CONFIG:
+      gst_tensor_trainer_set_prop_model_config_file_path (trainer, value);
       break;
     case PROP_INPUT_DIM:
       gst_tensor_trainer_set_prop_dimension (trainer, value, TRUE);
@@ -275,6 +349,22 @@ gst_tensor_trainer_set_property (GObject * object, guint prop_id,
     case PROP_OUTPUT_TYPE:
       gst_tensor_trainer_set_prop_type (trainer, value, FALSE);
       break;
+    case PROP_PUSH_OUTPUT:
+      trainer->push_output = g_value_get_boolean (value);
+      GST_INFO_OBJECT (trainer, "push output: %d", trainer->push_output);
+      break;
+    case PROP_NUM_INPUTS:
+      trainer->num_inputs = g_value_get_uint (value);
+      break;
+    case PROP_NUM_LABELS:
+      trainer->num_labels = g_value_get_uint (value);
+      break;
+    case PROP_NUM_TRAINING_SAMPLES:
+      trainer->num_training_samples = g_value_get_uint (value);
+      break;
+    case PROP_NUM_VALIDATION_SAMPLES:
+      trainer->num_validation_samples = g_value_get_uint (value);
+      break;
     default:
       G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
       break;
@@ -296,6 +386,9 @@ gst_tensor_trainer_get_property (GObject * object, guint prop_id,
     case PROP_FRAMEWORK:
       g_value_set_string (value, trainer->fw_name);
       break;
+    case PROP_MODEL_CONFIG:
+      g_value_set_string (value, trainer->model_config);
+      break;
     case PROP_INPUT_DIM:
       g_value_set_string (value, trainer->input_dimensions);
       break;
@@ -308,6 +401,21 @@ gst_tensor_trainer_get_property (GObject * object, guint prop_id,
     case PROP_OUTPUT_TYPE:
       g_value_set_string (value, trainer->output_type);
       break;
+    case PROP_PUSH_OUTPUT:
+      g_value_set_boolean (value, trainer->push_output);
+      break;
+    case PROP_NUM_INPUTS:
+      g_value_set_uint (value, trainer->num_inputs);
+      break;
+    case PROP_NUM_LABELS:
+      g_value_set_uint (value, trainer->num_labels);
+      break;
+    case PROP_NUM_TRAINING_SAMPLES:
+      g_value_set_uint (value, trainer->num_training_samples);
+      break;
+    case PROP_NUM_VALIDATION_SAMPLES:
+      g_value_set_uint (value, trainer->num_validation_samples);
+      break;
     default:
       G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
       break;
@@ -440,6 +548,8 @@ gst_tensor_trainer_start (GstBaseTransform * trans)
   if (trainer->fw_name)
     gst_tensor_trainer_find_framework (trainer, trainer->fw_name);
   if (trainer->fw) {
+    /* calc input tensors size */
+    gst_tensor_trainer_calc_input_tensors_size (trainer);
     /* create, compile */
     gst_tensor_trainer_create_framework (trainer);
   }
@@ -483,7 +593,10 @@ gst_tensor_trainer_transform (GstBaseTransform * trans, GstBuffer * inbuf,
   GstTensorMetaInfo in_meta[NNS_TENSOR_SIZE_LIMIT];
   GstTensorMetaInfo out_meta[NNS_TENSOR_SIZE_LIMIT];
 
+  pid_t pid = getpid ();
+  pid_t tid = syscall (SYS_gettid);
   trainer = GST_TENSOR_TRAINER_CAST (trans);
+  GST_ERROR_OBJECT (trainer, "pid: %d, tid: %d", pid, tid);
 
   /* Get all input tensors from inbuf */
   mem_blocks = gst_buffer_n_memory (inbuf);
@@ -507,6 +620,7 @@ gst_tensor_trainer_transform (GstBaseTransform * trans, GstBuffer * inbuf,
 
     in_tensors[i].data = in_info[i].data + header_size;
     in_tensors[i].size = in_info[i].size - header_size;
+    GST_INFO ("tensor size: %zd", in_tensors[i].size);
   }
 
   /* Prepare tensor to invoke */
@@ -528,66 +642,78 @@ gst_tensor_trainer_transform (GstBaseTransform * trans, GstBuffer * inbuf,
     }
     /* Copy to data pointer */
     invoke_tensors[i] = in_tensors[i];
+    GST_INFO ("in_tensors[%d].size= %zd", i, in_tensors[i].size);
+    GST_INFO ("in_tensors[%d].data: %p", i, in_tensors[i].data);
+    GST_INFO ("invoke_tensors[%d].size= %zd", i, invoke_tensors[i].size);
+    GST_INFO ("invoke_tensors[%d].data: %p", i, invoke_tensors[i].data);
   }
 
   /* Prepare output tensor */
-  for (i = 0; i < trainer->output_meta.num_tensors; i++) {
-    out_tensors[i].data = NULL;
-    out_tensors[i].size =
-        gst_tensor_trainer_get_tensor_size (trainer, i, FALSE);
-
-    /* Get header size */
-    header_size = 0;
-    out_flexible =
-        gst_tensor_pad_caps_is_flexible (GST_BASE_TRANSFORM_SRC_PAD (trans));
-    if (out_flexible) {
-      gst_tensor_info_convert_to_meta (&trainer->output_meta.info[i],
-          &out_meta[i]);
-      header_size = gst_tensor_meta_info_get_header_size (&out_meta[i]);
-      GST_INFO ("flexible header size:%zd", header_size);
-    } else {
-      GST_INFO ("not flexible header size:%zd", header_size);
-    }
+  if (trainer->push_output) {
+    for (i = 0; i < trainer->output_meta.num_tensors; i++) {
+      out_tensors[i].data = NULL;
+      out_tensors[i].size =
+          gst_tensor_trainer_get_tensor_size (trainer, i, FALSE);
+
+      /* Get header size */
+      header_size = 0;
+      out_flexible =
+          gst_tensor_pad_caps_is_flexible (GST_BASE_TRANSFORM_SRC_PAD (trans));
+      if (out_flexible) {
+        gst_tensor_info_convert_to_meta (&trainer->output_meta.info[i],
+            &out_meta[i]);
+        header_size = gst_tensor_meta_info_get_header_size (&out_meta[i]);
+        GST_INFO ("flexible header size:%zd", header_size);
+      } else {
+        GST_INFO ("not flexible header size:%zd", header_size);
+      }
 
-    out_mem[i] =
-        gst_allocator_alloc (NULL, out_tensors[i].size + header_size, NULL);
-    if (!out_mem[i]) {
-      GST_ERROR_OBJECT (trainer, "Failed to allocate memory");
-      goto error;
-    }
+      out_mem[i] =
+          gst_allocator_alloc (NULL, out_tensors[i].size + header_size, NULL);
+      if (!out_mem[i]) {
+        GST_ERROR_OBJECT (trainer, "Failed to allocate memory");
+        goto error;
+      }
 
-    if (!gst_memory_map (out_mem[i], &out_info[i], GST_MAP_WRITE)) {
-      GST_ERROR_OBJECT (trainer, "Could not map in_mem[%d] GstMemory", i);
-      goto error;
-    }
+      if (!gst_memory_map (out_mem[i], &out_info[i], GST_MAP_WRITE)) {
+        GST_ERROR_OBJECT (trainer, "Could not map in_mem[%d] GstMemory", i);
+        goto error;
+      }
 
-    out_tensors[i].data = out_info[i].data + header_size;
+      out_tensors[i].data = out_info[i].data + header_size;
 
-    /* Append header */
-    if (out_flexible) {
-      if (!gst_tensor_meta_info_update_header (&out_meta[i], out_info[i].data)) {
-        GST_ERROR_OBJECT (trainer, "Failed to update header ");
-        goto error;
+      /* Append header */
+      if (out_flexible) {
+        if (!gst_tensor_meta_info_update_header (&out_meta[i],
+                out_info[i].data)) {
+          GST_ERROR_OBJECT (trainer, "Failed to update header ");
+          goto error;
+        }
       }
     }
+    /* Call the trainer-subplugin callback, invoke */
+    ret =
+        trainer->fw->invoke_NN (&trainer->prop, &trainer->privateData,
+        invoke_tensors, out_tensors);
+
+    /* Free out info */
+    for (i = 0; i < trainer->output_meta.num_tensors; i++) {
+      if (out_mem[i])
+        gst_memory_unmap (out_mem[i], &out_info[i]);
+      //if (ret != 0) {
+      //  gst_allocator_free (out_mem[i]->allocator, out_mem[i]);
+      //}
+    }
+  } else {
+    ret =
+        trainer->fw->invoke_NN (&trainer->prop, &trainer->privateData,
+        invoke_tensors, NULL);
   }
 
-  /* Call the trainer-subplugin callback, invoke */
-  ret =
-      trainer->fw->invoke_NN (&trainer->prop, &trainer->privateData,
-      invoke_tensors, out_tensors);
-
-  /* Free map info and handle */
+  /* Free in info */
   for (i = 0; i < mem_blocks; i++)
     gst_memory_unmap (in_mem[i], &in_info[i]);
 
-  for (i = 0; i < trainer->output_meta.num_tensors; i++) {
-    gst_memory_unmap (out_mem[i], &out_info[i]);
-    //if (ret != 0) {
-    //  gst_allocator_free (out_mem[i]->allocator, out_mem[i]);
-    //}
-  }
-
   if (ret < 0) {
     GST_ERROR_OBJECT (trainer, "Invoke error");
     // return GST_FLOW_ERROR;
@@ -596,13 +722,15 @@ gst_tensor_trainer_transform (GstBaseTransform * trans, GstBuffer * inbuf,
     // return GST_BASE_TRANSFORM_FLOW_DROPPED;
   }
 
-  GST_INFO ("out buffer size : %zd", gst_buffer_get_size (outbuf));
   /*Update result */
-  for (i = 0; i < trainer->output_meta.num_tensors; i++) {
-    /* append the memory block to outbuf */
-    gst_buffer_append_memory (outbuf, out_mem[i]);
+  GST_INFO ("out buffer size : %zd", gst_buffer_get_size (outbuf));
+  if (trainer->push_output) {
+    for (i = 0; i < trainer->output_meta.num_tensors; i++) {
+      /* append the memory block to outbuf */
+      gst_buffer_append_memory (outbuf, out_mem[i]);
+    }
+    GST_INFO ("after out buffer size : %zd", gst_buffer_get_size (outbuf));
   }
-  GST_INFO ("after out buffer size : %zd", gst_buffer_get_size (outbuf));
 
   return GST_FLOW_OK;
 
@@ -678,7 +806,8 @@ gst_tensor_trainer_transform_caps (GstBaseTransform * trans,
   else
     result = gst_caps_from_string (CAPS_STRING);
 
-  GST_DEBUG_OBJECT (trans, "caps intersect without filter %" GST_PTR_FORMAT, result);
+  GST_DEBUG_OBJECT (trans, "caps intersect without filter %" GST_PTR_FORMAT,
+      result);
 
   if (filter) {
     GstCaps *intersection;
@@ -747,16 +876,30 @@ gst_tensor_trainer_set_caps (GstBaseTransform * trans, GstCaps * incaps,
  * @brief Handle "PROP_FRAMEWORK" for set-property
  */
 static void
-gst_tensor_trainer_set_prop_framework (GstTensorTrainer * trainer, const gchar * fw_name)
+gst_tensor_trainer_set_prop_framework (GstTensorTrainer * trainer,
+    const GValue * value)
 {
   g_free (trainer->fw_name);
-  trainer->fw_name = g_strdup (fw_name);
+  trainer->fw_name = g_value_dup_string (value);
   GST_INFO_OBJECT (trainer, "framework: %s", trainer->fw_name);
 
   /** @todo Check valid framework */
 }
 
 /**
+ * @brief Handle "PROP_MODEL_CONFIG" for set-property
+ */
+static void
+gst_tensor_trainer_set_prop_model_config_file_path (GstTensorTrainer * trainer,
+    const GValue * value)
+{
+  g_free (trainer->model_config);
+  trainer->model_config = g_value_dup_string (value);
+  GST_INFO_OBJECT (trainer, "model configuration file path: %s",
+      trainer->model_config);
+}
+
+/**
  * @brief Handle "PROP_INPUT_DIM" and "PROP_OUTPUT_DIM" for set-property
  */
 static void
@@ -819,6 +962,7 @@ gst_tensor_trainer_set_prop_type (GstTensorTrainer * trainer,
 {
   GstTensorsInfo *info;
   guint num_types;
+  guint i;
 
   if ((is_input && trainer->inputtype_configured) || (!is_input
           && trainer->outputtype_configured)) {
@@ -844,6 +988,9 @@ gst_tensor_trainer_set_prop_type (GstTensorTrainer * trainer,
 
   num_types =
       gst_tensors_info_parse_types_string (info, g_value_get_string (value));
+  for (i = 0; i < num_types; i++) {
+    trainer->tensors_inputtype[i] = info->info[i].type;
+  }
 
   info->num_tensors = num_types;
 }
@@ -857,6 +1004,7 @@ gst_tensor_trainer_find_framework (GstTensorTrainer * trainer, const char *name)
   const GstTensorFilterFramework *fw = NULL;
   gchar *str;
   g_return_if_fail (name != NULL);
+  g_return_if_fail (trainer != NULL);
 
   GST_INFO ("find framework: %s", name);
 
@@ -903,8 +1051,10 @@ gst_tensor_trainer_create_framework (GstTensorTrainer * trainer)
     return;
   }
   /* Test code, need to create with load ini file */
+  GST_ERROR ("%p", trainer->privateData);
   if (trainer->fw->open (&trainer->prop, &trainer->privateData) >= 0)
     trainer->fw_created = TRUE;
+  GST_ERROR ("%p", trainer->privateData);
 }
 
 /**
@@ -913,7 +1063,7 @@ gst_tensor_trainer_create_framework (GstTensorTrainer * trainer)
 static const GstTensorFilterFramework *
 gst_tensor_trainer_find_best_framework (const char *names)
 {
-  const GstTensorFilterFramework *fw = NULL; /* need to change to GstTensorTrainerFramework */
+  const GstTensorFilterFramework *fw = NULL;    /* need to change to GstTensorTrainerFramework */
   gchar **subplugins;
   guint i, len;
 
@@ -928,7 +1078,7 @@ gst_tensor_trainer_find_best_framework (const char *names)
     if (strlen (g_strstrip (subplugins[i])) == 0)
       continue;
 
-    fw = get_subplugin (NNS_SUBPLUGIN_FILTER, subplugins[i]); /* need to add trainer type to subpluginType */
+    fw = get_subplugin (NNS_SUBPLUGIN_FILTER, subplugins[i]);   /* need to add trainer type to subpluginType */
     if (fw) {
       GST_INFO ("i = %d found %s", i, subplugins[i]);
       break;
@@ -987,3 +1137,42 @@ gst_tensor_trainer_transform_size (GstBaseTransform * trans,
 
   return TRUE;
 }
+
+/**
+ * @brief Calculate the size of input tensors
+ */
+static void
+gst_tensor_trainer_calc_input_tensors_size (GstTensorTrainer * trainer)
+{
+  GstTensorsInfo *info;
+  guint i, j, idx, size[NNS_TENSOR_SIZE_LIMIT] = { 1, };
+
+  const guint get_tensor_size[] = {
+    [_NNS_INT32] = 4,
+    [_NNS_UINT32] = 4,
+    [_NNS_INT16] = 2,
+    [_NNS_UINT16] = 2,
+    [_NNS_INT8] = 1,
+    [_NNS_UINT8] = 1,
+    [_NNS_FLOAT64] = 8,
+    [_NNS_FLOAT32] = 4,
+    [_NNS_INT64] = 8,
+    [_NNS_UINT64] = 8,
+    [_NNS_FLOAT16] = 2,
+  };
+
+  g_return_if_fail (trainer != NULL);
+
+  info = &trainer->input_meta;
+  idx = 0;
+  for (i = 0; i < info->num_tensors; i++) {
+    for (j = 0; j < NNS_TENSOR_RANK_LIMIT; j++) {
+      size[idx] *= info->info[i].dimension[j];
+    }
+    trainer->tensors_inputsize[idx] =
+        size[idx] * get_tensor_size[trainer->tensors_inputtype[idx]];
+    GST_DEBUG_OBJECT (trainer, "trainer->tensors_inputsize[%d]= %d", idx,
+        trainer->tensors_inputsize[idx]);
+    idx++;
+  }
+}
index 258ebc9..ea90c84 100644 (file)
@@ -46,10 +46,17 @@ struct _GstTensorTrainer
   GstBaseTransform element;
 
   gchar *fw_name;
+  gchar *model_config;
   gchar *input_dimensions;
   gchar *output_dimensions;
   gchar *input_type;
   gchar *output_type;
+  gboolean push_output;
+  unsigned int num_inputs;
+  unsigned int num_labels;
+  unsigned int num_training_samples;
+  unsigned int num_validation_samples;
+
   GstTensorsInfo input_meta;
   GstTensorsInfo output_meta;
 
@@ -62,6 +69,9 @@ struct _GstTensorTrainer
   unsigned int input_ranks[NNS_TENSOR_SIZE_LIMIT];
   unsigned int output_ranks[NNS_TENSOR_SIZE_LIMIT];
 
+  tensor_type tensors_inputtype[NNS_TENSOR_SIZE_LIMIT];
+  unsigned int tensors_inputsize[NNS_TENSOR_SIZE_LIMIT];
+
   /* draft */
   int fw_opened;
   int fw_compiled;
index e5c065b..16577c7 100644 (file)
@@ -17,7 +17,7 @@ nnstreamer_sources += files(
   'gsttensor_sparseenc.c',
   'gsttensor_sparseutil.c',
   'gsttensor_split.c',
-  'gsttensor_transform.c'
+  'gsttensor_transform.c',
   'gsttensor_trainer.c'
 )