From 5cff3ca3b16fc55cfd162ad365f9cfa1102821f9 Mon Sep 17 00:00:00 2001 From: Yongjoo Ahn Date: Thu, 23 Feb 2023 18:27:10 +0900 Subject: [PATCH] [filter] Fix tflite filter to handle extra tensors - Change the configuration methods to handle extra tensors Signed-off-by: Yongjoo Ahn --- .../tensor_filter/tensor_filter_tensorflow_lite.cc | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/ext/nnstreamer/tensor_filter/tensor_filter_tensorflow_lite.cc b/ext/nnstreamer/tensor_filter/tensor_filter_tensorflow_lite.cc index 1568b81..180e99c 100644 --- a/ext/nnstreamer/tensor_filter/tensor_filter_tensorflow_lite.cc +++ b/ext/nnstreamer/tensor_filter/tensor_filter_tensorflow_lite.cc @@ -656,21 +656,26 @@ TFLiteInterpreter::setTensorProp ( const std::vector &tensor_idx_list, GstTensorsInfo *tensorMeta) { tensorMeta->num_tensors = tensor_idx_list.size (); + if (tensorMeta->num_tensors > NNS_TENSOR_SIZE_LIMIT) { + ml_logi ("Create extra tensor info for the tflite model"); + gst_tensors_info_extra_create (tensorMeta); + } for (unsigned int i = 0; i < tensorMeta->num_tensors; ++i) { int idx = tensor_idx_list[i]; + GstTensorInfo *info = gst_tensors_info_get_nth_info (tensorMeta, i); - if (getTensorDim (idx, tensorMeta->info[i].dimension)) { + if (getTensorDim (idx, info->dimension)) { ml_loge ("failed to get the dimension of input tensors"); return -1; } - tensorMeta->info[i].type = getTensorType (interpreter->tensor (idx)->type); - tensorMeta->info[i].name = g_strdup (interpreter->tensor (idx)->name); + info->type = getTensorType (interpreter->tensor (idx)->type); + info->name = g_strdup (interpreter->tensor (idx)->name); #if (DBG) - gchar *dim_str = gst_tensor_get_dimension_string (tensorMeta->info[i].dimension); + gchar *dim_str = gst_tensor_get_dimension_string (info->dimension); ml_logi ("tensorMeta[%d] >> name[%s], type[%d], dim[%s]", i, - tensorMeta->info[i].name, tensorMeta->info[i].type, dim_str); + info->name, info->type, dim_str); g_free (dim_str); #endif } @@ -816,7 +821,8 @@ TFLiteInterpreter::cacheInOutTensorPtr () tensor_idx = interpreter->inputs ()[i]; tensor_ptr = interpreter->tensor (tensor_idx); - if (tensor_ptr->bytes != gst_tensor_info_get_size (&inputTensorMeta.info[i])) + GstTensorInfo *info = gst_tensors_info_get_nth_info (&inputTensorMeta, i); + if (tensor_ptr->bytes != gst_tensor_info_get_size (info)) goto fail_exit; inputTensorPtr.push_back (tensor_ptr); @@ -828,7 +834,8 @@ TFLiteInterpreter::cacheInOutTensorPtr () tensor_idx = interpreter->outputs ()[i]; tensor_ptr = interpreter->tensor (tensor_idx); - if (tensor_ptr->bytes != gst_tensor_info_get_size (&outputTensorMeta.info[i])) + GstTensorInfo *info = gst_tensors_info_get_nth_info (&outputTensorMeta, i); + if (tensor_ptr->bytes != gst_tensor_info_get_size (info)) goto fail_exit; outputTensorPtr.push_back (tensor_ptr); -- 2.7.4