From: hyunil park Date: Thu, 1 Dec 2022 08:56:57 +0000 (+0900) Subject: [tensor_trainer] Apply tensor trainer sub-plugin structure X-Git-Tag: accepted/tizen/unified/20230215.155633~11 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=7b76b797c768bb34bae32d68eed8e675b9953c17;p=platform%2Fupstream%2Fnnstreamer.git [tensor_trainer] Apply tensor trainer sub-plugin structure - Apply GstTensorTrainerProperties for internal data required by sub-plugin - Apply GstTensorTrainerFramework for sub-plugin definition - Apply trainer subpluginType(NNS_SUBPLUGIN_TRAINER) to get, register and unregister subplugin - Add model-save-path property to set path for saving trained model Signed-off-by: hyunil park --- diff --git a/gst/nnstreamer/elements/gsttensor_trainer.c b/gst/nnstreamer/elements/gsttensor_trainer.c index 1a469d9..911aa2f 100644 --- a/gst/nnstreamer/elements/gsttensor_trainer.c +++ b/gst/nnstreamer/elements/gsttensor_trainer.c @@ -74,6 +74,7 @@ enum PROP_0, PROP_FRAMEWORK, PROP_MODEL_CONFIG, + PROP_MODEL_SAVE_PATH, PROP_INPUT_DIM, PROP_OUTPUT_DIM, PROP_INPUT_TYPE, @@ -91,7 +92,6 @@ static void gst_tensor_trainer_set_property (GObject * object, guint prop_id, const GValue * value, GParamSpec * pspec); static void gst_tensor_trainer_get_property (GObject * object, guint prop_id, GValue * value, GParamSpec * pspec); -static gboolean gst_tensor_trainer_start (GstBaseTransform * trans); static gboolean gst_tensor_trainer_stop (GstBaseTransform * trans); static void gst_tensor_trainer_finalize (GObject * object); static gboolean gst_tensor_trainer_sink_event (GstBaseTransform * trans, @@ -112,24 +112,23 @@ static gboolean gst_tensor_trainer_transform_size (GstBaseTransform * trans, GstPadDirection direction, GstCaps * caps, gsize size, GstCaps * othercaps, gsize * othersize); -static void -gst_tensor_trainer_set_prop_framework (GstTensorTrainer * trainer, +static void gst_tensor_trainer_set_prop_framework (GstTensorTrainer * trainer, const GValue * value); static void gst_tensor_trainer_set_prop_model_config_file_path (GstTensorTrainer * trainer, const GValue * value); +static void gst_tensor_trainer_set_model_save_path (GstTensorTrainer * trainer, + const GValue * value); static void gst_tensor_trainer_set_prop_dimension (GstTensorTrainer * trainer, const GValue * value, const gboolean is_input); static void gst_tensor_trainer_set_prop_type (GstTensorTrainer * trainer, const GValue * value, const gboolean is_input); -static const GstTensorFilterFramework - * gst_tensor_trainer_find_best_framework (const char *names); static void gst_tensor_trainer_find_framework (GstTensorTrainer * trainer, const char *name); static void gst_tensor_trainer_create_framework (GstTensorTrainer * trainer); static gsize gst_tensor_trainer_get_tensor_size (GstTensorTrainer * trainer, guint index, gboolean is_input); -static void gst_tensor_trainer_calc_input_tensors_size (GstTensorTrainer * - trainer); +static void gst_tensor_trainer_create_model (GstTensorTrainer * trainer); +static void gst_tensor_trainer_train_model (GstTensorTrainer * trainer); /** * @brief initialize the tensor_trainer's class @@ -158,8 +157,6 @@ gst_tensor_trainer_class_init (GstTensorTrainerClass * klass) gstelement_class->change_state = GST_DEBUG_FUNCPTR (gst_tensor_trainer_change_state); - /* Called when the element starts processing */ - trans_class->start = GST_DEBUG_FUNCPTR (gst_tensor_trainer_start); /* Called when the element stop processing */ trans_class->stop = GST_DEBUG_FUNCPTR (gst_tensor_trainer_stop); @@ -196,6 +193,14 @@ gst_tensor_trainer_class_init (GstTensorTrainerClass * klass) G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY | G_PARAM_STATIC_STRINGS)); + g_object_class_install_property (gobject_class, PROP_MODEL_SAVE_PATH, + g_param_spec_string ("model-save-path", "Model save path", + "Path to save the trained model in framework, if model-config " + "contains information about the save file, it is ignored", + DEFAULT_STR_PROP_VALUE, + G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY | + G_PARAM_STATIC_STRINGS)); + g_object_class_install_property (gobject_class, PROP_INPUT_DIM, g_param_spec_string ("input-dim", "Input dimension", "Input tensors dimension from inner array, up to 4 dimensions ?", "", @@ -277,6 +282,7 @@ gst_tensor_trainer_init (GstTensorTrainer * trainer) GST_DEBUG (""); trainer->fw_name = g_strdup (DEFAULT_PROP_FRAMEWORK); trainer->model_config = g_strdup (DEFAULT_STR_PROP_VALUE); + trainer->model_save_path = g_strdup (DEFAULT_STR_PROP_VALUE); trainer->input_dimensions = g_strdup (DEFAULT_STR_PROP_VALUE); trainer->output_dimensions = g_strdup (DEFAULT_STR_PROP_VALUE); trainer->input_type = g_strdup (DEFAULT_STR_PROP_VALUE); @@ -284,7 +290,7 @@ gst_tensor_trainer_init (GstTensorTrainer * trainer) trainer->push_output = FALSE; trainer->fw = NULL; - trainer->fw_opened = 0; /* for test */ + trainer->fw_created = 0; /* for test */ trainer->configured = 0; trainer->input_configured = 0; trainer->output_configured = 0; @@ -304,14 +310,14 @@ gst_tensor_trainer_finalize (GObject * object) g_free (trainer->fw_name); g_free (trainer->model_config); + g_free (trainer->model_save_path); g_free (trainer->input_dimensions); g_free (trainer->output_dimensions); g_free (trainer->input_type); g_free (trainer->output_type); - GST_DEBUG ("trainer->fw_created=%d", trainer->fw_created); - if (trainer->fw_created) { - trainer->fw->close (&trainer->prop, &trainer->privateData); + if (trainer->fw_created && trainer->fw) { + trainer->fw->destroy (trainer->fw, &trainer->prop, &trainer->privateData); } /* need to free prop data */ @@ -337,6 +343,9 @@ gst_tensor_trainer_set_property (GObject * object, guint prop_id, case PROP_MODEL_CONFIG: gst_tensor_trainer_set_prop_model_config_file_path (trainer, value); break; + case PROP_MODEL_SAVE_PATH: + gst_tensor_trainer_set_model_save_path (trainer, value); + break; case PROP_INPUT_DIM: gst_tensor_trainer_set_prop_dimension (trainer, value, TRUE); break; @@ -354,16 +363,16 @@ gst_tensor_trainer_set_property (GObject * object, guint prop_id, GST_INFO_OBJECT (trainer, "push output: %d", trainer->push_output); break; case PROP_NUM_INPUTS: - trainer->num_inputs = g_value_get_uint (value); + trainer->prop.num_inputs = g_value_get_uint (value); break; case PROP_NUM_LABELS: - trainer->num_labels = g_value_get_uint (value); + trainer->prop.num_labels = g_value_get_uint (value); break; case PROP_NUM_TRAINING_SAMPLES: - trainer->num_training_samples = g_value_get_uint (value); + trainer->prop.num_train_samples = g_value_get_uint (value); break; case PROP_NUM_VALIDATION_SAMPLES: - trainer->num_validation_samples = g_value_get_uint (value); + trainer->prop.num_valid_samples = g_value_get_uint (value); break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec); @@ -389,6 +398,9 @@ gst_tensor_trainer_get_property (GObject * object, guint prop_id, case PROP_MODEL_CONFIG: g_value_set_string (value, trainer->model_config); break; + case PROP_MODEL_SAVE_PATH: + g_value_set_string (value, trainer->model_save_path); + break; case PROP_INPUT_DIM: g_value_set_string (value, trainer->input_dimensions); break; @@ -405,16 +417,16 @@ gst_tensor_trainer_get_property (GObject * object, guint prop_id, g_value_set_boolean (value, trainer->push_output); break; case PROP_NUM_INPUTS: - g_value_set_uint (value, trainer->num_inputs); + g_value_set_uint (value, trainer->prop.num_inputs); break; case PROP_NUM_LABELS: - g_value_set_uint (value, trainer->num_labels); + g_value_set_uint (value, trainer->prop.num_labels); break; case PROP_NUM_TRAINING_SAMPLES: - g_value_set_uint (value, trainer->num_training_samples); + g_value_set_uint (value, trainer->prop.num_train_samples); break; case PROP_NUM_VALIDATION_SAMPLES: - g_value_set_uint (value, trainer->num_validation_samples); + g_value_set_uint (value, trainer->prop.num_valid_samples); break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec); @@ -439,11 +451,12 @@ gst_tensor_trainer_change_state (GstElement * element, case GST_STATE_CHANGE_READY_TO_PAUSED: GST_INFO_OBJECT (trainer, "READY_TO_PAUSED"); + gst_tensor_trainer_create_model (trainer); break; case GST_STATE_CHANGE_PAUSED_TO_PLAYING: GST_INFO_OBJECT (trainer, "PAUSED_TO_PLAYING"); - /* start or resume model train */ + gst_tensor_trainer_train_model (trainer); break; default: @@ -535,28 +548,6 @@ gst_tensor_trainer_src_event (GstBaseTransform * trans, GstEvent * event) return GST_BASE_TRANSFORM_CLASS (parent_class)->src_event (trans, event); } - -/** - * @brief Called when the element starts processing. optional vmethod of BaseTransform - */ -static gboolean -gst_tensor_trainer_start (GstBaseTransform * trans) -{ - GstTensorTrainer *trainer; - trainer = GST_TENSOR_TRAINER_CAST (trans); - - if (trainer->fw_name) - gst_tensor_trainer_find_framework (trainer, trainer->fw_name); - if (trainer->fw) { - /* calc input tensors size */ - gst_tensor_trainer_calc_input_tensors_size (trainer); - /* create, compile */ - gst_tensor_trainer_create_framework (trainer); - } - - return TRUE; -} - /** * @brief Called when the element stops processing. optional vmethod of BaseTransform */ @@ -625,15 +616,17 @@ gst_tensor_trainer_transform (GstBaseTransform * trans, GstBuffer * inbuf, /* Prepare tensor to invoke */ /* Check number of input tensors */ - if (mem_blocks != trainer->input_meta.num_tensors) { + GST_ERROR ("num_tensors: %d", trainer->prop.input_meta.num_tensors); + if (mem_blocks != trainer->prop.input_meta.num_tensors) { GST_ERROR_OBJECT (trainer, "Invalid memory blocks(%d)," "number of input tensors may be (%d)", mem_blocks, - trainer->input_meta.num_tensors); + trainer->prop.input_meta.num_tensors); goto error; } /* Check size of input tensors */ - for (i = 0; i < trainer->input_meta.num_tensors; i++) { + GST_ERROR ("num_tensors: %d", trainer->prop.input_meta.num_tensors); + for (i = 0; i < trainer->prop.input_meta.num_tensors; i++) { expected = gst_tensor_trainer_get_tensor_size (trainer, i, TRUE); if (expected != in_tensors[i].size) { GST_ERROR_OBJECT (trainer, "Invalid tensor size (%u'th memory chunk: %zd)" @@ -693,7 +686,7 @@ gst_tensor_trainer_transform (GstBaseTransform * trans, GstBuffer * inbuf, } /* Call the trainer-subplugin callback, invoke */ ret = - trainer->fw->invoke_NN (&trainer->prop, &trainer->privateData, + trainer->fw->invoke (trainer->fw, &trainer->prop, trainer->privateData, invoke_tensors, out_tensors); /* Free out info */ @@ -706,7 +699,7 @@ gst_tensor_trainer_transform (GstBaseTransform * trans, GstBuffer * inbuf, } } else { ret = - trainer->fw->invoke_NN (&trainer->prop, &trainer->privateData, + trainer->fw->invoke (trainer->fw, &trainer->prop, trainer->privateData, invoke_tensors, NULL); } @@ -780,7 +773,9 @@ gst_tensor_trainer_transform_caps (GstBaseTransform * trans, gst_tensors_config_init (&in_config); gst_tensors_config_init (&out_config); structure = gst_caps_get_structure (caps, 0); - gst_tensors_config_from_structure (&in_config, structure); + + if (!gst_tensors_config_from_structure (&in_config, structure)) + return NULL; /* Set framerate from input config */ out_config.rate_n = in_config.rate_n; @@ -790,7 +785,7 @@ gst_tensor_trainer_transform_caps (GstBaseTransform * trans, for trainer->input_meta and trainer->output_meta */ if (direction == GST_PAD_SRC) { if (trainer->input_configured) { - gst_tensors_info_copy (&out_config.info, &trainer->input_meta); + gst_tensors_info_copy (&out_config.info, &trainer->prop.input_meta); configured = TRUE; } } else { @@ -857,9 +852,11 @@ gst_tensor_trainer_set_caps (GstBaseTransform * trans, GstCaps * incaps, gst_tensors_config_init (&in_config); structure = gst_caps_get_structure (incaps, 0); - gst_tensors_config_from_structure (&in_config, structure); - if (!gst_tensors_info_is_equal (&in_config.info, &trainer->input_meta)) { + if (!gst_tensors_config_from_structure (&in_config, structure)) + return FALSE; + + if (!gst_tensors_info_is_equal (&in_config.info, &trainer->prop.input_meta)) { GST_ERROR_OBJECT (trainer, "The input tensors info is different between incaps and set property value. " "Please check pipeline's input tensor info and tensor_trainer's set property values" @@ -895,8 +892,23 @@ gst_tensor_trainer_set_prop_model_config_file_path (GstTensorTrainer * trainer, { g_free (trainer->model_config); trainer->model_config = g_value_dup_string (value); + trainer->prop.model_config = trainer->model_config; GST_INFO_OBJECT (trainer, "model configuration file path: %s", - trainer->model_config); + trainer->prop.model_config); +} + +/** + * @brief Handle "PROP_MODEL_SAVE_PATH" for set-property + */ +static void +gst_tensor_trainer_set_model_save_path (GstTensorTrainer * trainer, + const GValue * value) +{ + g_free (trainer->model_save_path); + trainer->model_save_path = g_value_dup_string (value); + trainer->prop.model_save_path = trainer->model_save_path; + GST_INFO_OBJECT (trainer, "file path to save the model: %s", + trainer->prop.model_save_path); } /** @@ -921,7 +933,7 @@ gst_tensor_trainer_set_prop_dimension (GstTensorTrainer * trainer, } if (is_input) { - info = &trainer->input_meta; + info = &trainer->prop.input_meta; rank = trainer->input_ranks; trainer->input_configured = TRUE; g_free (trainer->input_dimensions); @@ -962,7 +974,6 @@ gst_tensor_trainer_set_prop_type (GstTensorTrainer * trainer, { GstTensorsInfo *info; guint num_types; - guint i; if ((is_input && trainer->inputtype_configured) || (!is_input && trainer->outputtype_configured)) { @@ -973,7 +984,7 @@ gst_tensor_trainer_set_prop_type (GstTensorTrainer * trainer, } if (is_input) { - info = &trainer->input_meta; + info = &trainer->prop.input_meta; trainer->inputtype_configured = TRUE; g_free (trainer->input_type); trainer->input_type = g_value_dup_string (value); @@ -988,9 +999,6 @@ gst_tensor_trainer_set_prop_type (GstTensorTrainer * trainer, num_types = gst_tensors_info_parse_types_string (info, g_value_get_string (value)); - for (i = 0; i < num_types; i++) { - trainer->tensors_inputtype[i] = info->info[i].type; - } info->num_tensors = num_types; } @@ -1001,30 +1009,14 @@ gst_tensor_trainer_set_prop_type (GstTensorTrainer * trainer, static void gst_tensor_trainer_find_framework (GstTensorTrainer * trainer, const char *name) { - const GstTensorFilterFramework *fw = NULL; - gchar *str; + const GstTensorTrainerFramework *fw = NULL; + g_return_if_fail (name != NULL); g_return_if_fail (trainer != NULL); GST_INFO ("find framework: %s", name); - /* Need to add trainer type to subpluginType */ - fw = get_subplugin (NNS_SUBPLUGIN_FILTER, name); - - if (fw == NULL) { - /*Get sub-plugin priority from ini file and find sub-plugin */ - str = nnsconf_get_custom_value_string (name, "subplugin_priority"); - fw = gst_tensor_trainer_find_best_framework (str); - g_free (str); - } - - if (fw == NULL) { - /* Check the filter-alias from ini file */ - str = nnsconf_get_custom_value_string ("filter-aliases", name); - fw = gst_tensor_trainer_find_best_framework (str); - g_free (str); - } - + fw = get_subplugin (NNS_SUBPLUGIN_TRAINER, name); if (fw) { GST_INFO_OBJECT (trainer, "find framework %s:%p", trainer->fw_name, fw); trainer->fw = fw; @@ -1039,57 +1031,28 @@ gst_tensor_trainer_find_framework (GstTensorTrainer * trainer, const char *name) static void gst_tensor_trainer_create_framework (GstTensorTrainer * trainer) { - if (!trainer->fw || trainer->fw_opened) { + g_return_if_fail (trainer != NULL); + + if (!trainer->fw || trainer->fw_created) { GST_ERROR_OBJECT (trainer, "fw is not opened(%d) or fw is null(%p)", - trainer->fw_opened, trainer->fw); + trainer->fw_created, trainer->fw); return; } /* For test */ - if (!trainer->fw->open) { /* fw->create, create model with configuration file */ + if (!trainer->fw->create) { /* fw->create, create model with configuration file */ GST_ERROR_OBJECT (trainer, "Could not find fw->create"); return; } /* Test code, need to create with load ini file */ GST_ERROR ("%p", trainer->privateData); - if (trainer->fw->open (&trainer->prop, &trainer->privateData) >= 0) + if (trainer->fw->create (trainer->fw, &trainer->prop, + &trainer->privateData) >= 0) trainer->fw_created = TRUE; GST_ERROR ("%p", trainer->privateData); } /** - * @brief Find sub-plugin trainer given the name list - */ -static const GstTensorFilterFramework * -gst_tensor_trainer_find_best_framework (const char *names) -{ - const GstTensorFilterFramework *fw = NULL; /* need to change to GstTensorTrainerFramework */ - gchar **subplugins; - guint i, len; - - if (names == NULL || names[0] == '\0') - return NULL; - - subplugins = g_strsplit_set (names, " ,;", -1); - - len = g_strv_length (subplugins); - - for (i = 0; i < len; i++) { - if (strlen (g_strstrip (subplugins[i])) == 0) - continue; - - fw = get_subplugin (NNS_SUBPLUGIN_FILTER, subplugins[i]); /* need to add trainer type to subpluginType */ - if (fw) { - GST_INFO ("i = %d found %s", i, subplugins[i]); - break; - } - } - g_strfreev (subplugins); - - return fw; -} - -/** * @brief Calculate tensor buffer size */ gsize @@ -1099,7 +1062,7 @@ gst_tensor_trainer_get_tensor_size (GstTensorTrainer * trainer, guint index, GstTensorsInfo *info; if (is_input) - info = &trainer->input_meta; + info = &trainer->prop.input_meta; else info = &trainer->output_meta; @@ -1139,40 +1102,90 @@ gst_tensor_trainer_transform_size (GstBaseTransform * trans, } /** - * @brief Calculate the size of input tensors + * @brief Create model */ static void -gst_tensor_trainer_calc_input_tensors_size (GstTensorTrainer * trainer) +gst_tensor_trainer_create_model (GstTensorTrainer * trainer) { - GstTensorsInfo *info; - guint i, j, idx, size[NNS_TENSOR_SIZE_LIMIT] = { 1, }; - - const guint get_tensor_size[] = { - [_NNS_INT32] = 4, - [_NNS_UINT32] = 4, - [_NNS_INT16] = 2, - [_NNS_UINT16] = 2, - [_NNS_INT8] = 1, - [_NNS_UINT8] = 1, - [_NNS_FLOAT64] = 8, - [_NNS_FLOAT32] = 4, - [_NNS_INT64] = 8, - [_NNS_UINT64] = 8, - [_NNS_FLOAT16] = 2, - }; + g_return_if_fail (trainer != NULL); + GST_DEBUG_OBJECT (trainer, "called"); + + if (trainer->fw_name) + gst_tensor_trainer_find_framework (trainer, trainer->fw_name); + if (trainer->fw) { + /* model create and compile */ + gst_tensor_trainer_create_framework (trainer); + } +} +/** + * @brief Train model + */ +static void +gst_tensor_trainer_train_model (GstTensorTrainer * trainer) +{ + gint ret = -1; g_return_if_fail (trainer != NULL); + g_return_if_fail (trainer->fw != NULL); + g_return_if_fail (trainer->fw->train != NULL); - info = &trainer->input_meta; - idx = 0; - for (i = 0; i < info->num_tensors; i++) { - for (j = 0; j < NNS_TENSOR_RANK_LIMIT; j++) { - size[idx] *= info->info[i].dimension[j]; - } - trainer->tensors_inputsize[idx] = - size[idx] * get_tensor_size[trainer->tensors_inputtype[idx]]; - GST_DEBUG_OBJECT (trainer, "trainer->tensors_inputsize[%d]= %d", idx, - trainer->tensors_inputsize[idx]); - idx++; + GST_DEBUG_OBJECT (trainer, "called"); + ret = trainer->fw->train (trainer->fw, &trainer->prop, trainer->privateData); + if (ret != 0) { + GST_ERROR_OBJECT (trainer, "model train is failed"); + } +} + +/** + * @brief Trainer's sub-plugin should call this function to register itself. + * @param[in] ttsp tensor_trainer sub-plugin to be registered. + * @return TRUE if registered. FALSE is failed or duplicated. + */ +int +nnstreamer_trainer_probe (GstTensorTrainerFramework * ttsp) +{ + GstTensorTrainerFrameworkInfo info; + GstTensorTrainerProperties prop; + const char *name = NULL; + int ret = 0; + + g_return_val_if_fail (ttsp != NULL, 0); + + memset (&prop, 0, sizeof (GstTensorTrainerProperties)); + gst_tensors_info_init (&prop.input_meta); + + if (ret != ttsp->getFrameworkInfo (ttsp, &prop, NULL, &info)) { + GST_ERROR ("getFrameworkInfo() failed"); + return FALSE; } + name = info.name; + + return register_subplugin (NNS_SUBPLUGIN_TRAINER, name, ttsp); +} + +/** + * @brief Trainer's sub-plugin may call this to unregister itself. + * @param[in] ttsp tensor_trainer sub-plugin to be unregistered. + * @return TRUE if unregistered. FALSE is failed. + */ +int +nnstreamer_trainer_exit (GstTensorTrainerFramework * ttsp) +{ + GstTensorTrainerFrameworkInfo info; + GstTensorTrainerProperties prop; + const char *name = NULL; + int ret = 0; + + g_return_val_if_fail (ttsp != NULL, 0); + + memset (&prop, 0, sizeof (GstTensorTrainerProperties)); + gst_tensors_info_init (&prop.input_meta); + + if (ret != ttsp->getFrameworkInfo (ttsp, &prop, NULL, &info)) { + GST_ERROR ("getFrameworkInfo() failed"); + return FALSE; + } + name = info.name; + + return unregister_subplugin (NNS_SUBPLUGIN_TRAINER, name); } diff --git a/gst/nnstreamer/elements/gsttensor_trainer.h b/gst/nnstreamer/elements/gsttensor_trainer.h index ea90c84..594f56f 100644 --- a/gst/nnstreamer/elements/gsttensor_trainer.h +++ b/gst/nnstreamer/elements/gsttensor_trainer.h @@ -20,7 +20,7 @@ #include #include -#include +#include G_BEGIN_DECLS #define GST_TYPE_TENSOR_TRAINER \ @@ -37,7 +37,6 @@ G_BEGIN_DECLS typedef struct _GstTensorTrainer GstTensorTrainer; typedef struct _GstTensorTrainerClass GstTensorTrainerClass; - /** * @brief GstTensorTrainer data structure */ @@ -47,42 +46,29 @@ struct _GstTensorTrainer gchar *fw_name; gchar *model_config; + gchar *model_save_path; gchar *input_dimensions; gchar *output_dimensions; gchar *input_type; gchar *output_type; gboolean push_output; - unsigned int num_inputs; - unsigned int num_labels; - unsigned int num_training_samples; - unsigned int num_validation_samples; - - GstTensorsInfo input_meta; - GstTensorsInfo output_meta; gboolean configured; - int input_configured; int output_configured; int inputtype_configured; int outputtype_configured; unsigned int input_ranks[NNS_TENSOR_SIZE_LIMIT]; unsigned int output_ranks[NNS_TENSOR_SIZE_LIMIT]; - - tensor_type tensors_inputtype[NNS_TENSOR_SIZE_LIMIT]; - unsigned int tensors_inputsize[NNS_TENSOR_SIZE_LIMIT]; + GstTensorsInfo output_meta; /* draft */ - int fw_opened; - int fw_compiled; - int fw_fitted; int fw_created; int fw_stop; - int fw_paused; void *privateData; /**< NNFW plugin's private data is stored here */ - const GstTensorFilterFramework *fw; /* for test, need to make */ - GstTensorFilterProperties prop; /**< NNFW plugin's properties */ + const GstTensorTrainerFramework *fw; /* for test, need to make */ + GstTensorTrainerProperties prop; /**< NNFW plugin's properties */ }; /** diff --git a/jni/nnstreamer.mk b/jni/nnstreamer.mk index 0f5e9f9..adcb060 100644 --- a/jni/nnstreamer.mk +++ b/jni/nnstreamer.mk @@ -67,6 +67,7 @@ NNSTREAMER_PLUGINS_SRCS := \ $(NNSTREAMER_GST_HOME)/elements/gsttensor_sparseenc.c \ $(NNSTREAMER_GST_HOME)/elements/gsttensor_sparseutil.c \ $(NNSTREAMER_GST_HOME)/elements/gsttensor_split.c \ + $(NNSTREAMER_GST_HOME)/elements/gsttensor_trainer.c \ $(NNSTREAMER_GST_HOME)/elements/gsttensor_transform.c \ $(NNSTREAMER_GST_HOME)/tensor_filter/tensor_filter.c