From: HyoungjooAhn Date: Mon, 19 Nov 2018 06:50:11 +0000 (+0900) Subject: [Filter/TFLite] support dynamic model X-Git-Tag: v0.0.3~72 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ba5bce47f2061848dd5978cbd27922b1ed4c31c8;p=platform%2Fupstream%2Fnnstreamer.git [Filter/TFLite] support dynamic model User open tflite model when a model is already loaded, compare it with previous one and if it is different, make the object the new one with new model. Signed-off-by: HyoungjooAhn --- diff --git a/gst/tensor_filter/tensor_filter_tensorflow_lite.c b/gst/tensor_filter/tensor_filter_tensorflow_lite.c index e5d93c6..790417b 100644 --- a/gst/tensor_filter/tensor_filter_tensorflow_lite.c +++ b/gst/tensor_filter/tensor_filter_tensorflow_lite.c @@ -29,6 +29,7 @@ #include "tensor_filter.h" #include "tensor_filter_tensorflow_lite_core.h" #include +#include /** * @brief internal data of tensorflow lite @@ -39,6 +40,21 @@ struct _Tflite_data }; typedef struct _Tflite_data tflite_data; + +/** + * @brief Free privateData and move on. + */ +static void +tflite_close (const GstTensorFilter * filter, void **private_data) +{ + tflite_data *tf; + tf = *private_data; + tflite_core_delete (tf->tflite_private_data); + g_free (tf); + *private_data = NULL; + g_assert (filter->privateData == NULL); +} + /** * @brief Load tensorflow lite modelfile * @param filter : tensor_filter instance @@ -53,7 +69,13 @@ tflite_loadModelFile (const GstTensorFilter * filter, void **private_data) tflite_data *tf; if (filter->privateData != NULL) { /** @todo : Check the integrity of filter->data and filter->model_file, nnfw */ - return 1; + tf = *private_data; + if (strcmp (filter->prop.model_file, + tflite_core_getModelPath (tf->tflite_private_data))) { + tflite_close (filter, private_data); + } else { + return 1; + } } tf = g_new0 (tflite_data, 1); /** initialize tf Fill Zero! */ *private_data = tf; @@ -127,20 +149,6 @@ tflite_getOutputDim (const GstTensorFilter * filter, void **private_data, return ret; } -/** - * @brief Free privateData and move on. - */ -static void -tflite_close (const GstTensorFilter * filter, void **private_data) -{ - tflite_data *tf; - tf = *private_data; - tflite_core_delete (tf->tflite_private_data); - g_free (tf); - *private_data = NULL; - g_assert (filter->privateData == NULL); -} - GstTensorFilterFramework NNS_support_tensorflow_lite = { .name = "tensorflow-lite", .allow_in_place = FALSE, /** @todo: support this to optimize performance later. */ diff --git a/gst/tensor_filter/tensor_filter_tensorflow_lite_core.cc b/gst/tensor_filter/tensor_filter_tensorflow_lite_core.cc index 67dd44f..898eaa5 100644 --- a/gst/tensor_filter/tensor_filter_tensorflow_lite_core.cc +++ b/gst/tensor_filter/tensor_filter_tensorflow_lite_core.cc @@ -87,6 +87,16 @@ TFLiteCore::init() } /** + * @brief get the model path + * @return the model path. + */ +const char* +TFLiteCore::getModelPath() +{ + return model_path; +} + +/** * @brief get millisecond for time profiling. * @note it returns the millisecond. * @param t : the time struct. @@ -396,6 +406,7 @@ tflite_core_delete (void *tflite) /** * @brief initialize the object with tflite model + * @param tflite : the class object * @return 0 if OK. non-zero if error. */ int @@ -407,6 +418,18 @@ tflite_core_init (void *tflite) } /** + * @brief get the model path + * @param tflite : the class object + * @return the model path. + */ +const char * +tflite_core_getModelPath (void *tflite) +{ + TFLiteCore *c = (TFLiteCore *) tflite; + return c->getModelPath(); +} + +/** * @brief get the Dimension of Input Tensor of model * @param tflite : the class object * @param[out] info Structure for tensor info. diff --git a/gst/tensor_filter/tensor_filter_tensorflow_lite_core.h b/gst/tensor_filter/tensor_filter_tensorflow_lite_core.h index a4dd6f1..e3de21e 100644 --- a/gst/tensor_filter/tensor_filter_tensorflow_lite_core.h +++ b/gst/tensor_filter/tensor_filter_tensorflow_lite_core.h @@ -46,6 +46,7 @@ public: int init(); int loadModel (); + const char* getModelPath(); int setInputTensorProp (); int setOutputTensorProp (); int getInputTensorSize ();