[filter] Fix tflite filter to handle extra tensors
authorYongjoo Ahn <yongjoo1.ahn@samsung.com>
Thu, 23 Feb 2023 09:27:10 +0000 (18:27 +0900)
committerSangjung Woo <again4you@gmail.com>
Thu, 20 Apr 2023 14:21:53 +0000 (23:21 +0900)
- Change the configuration methods to handle extra tensors

Signed-off-by: Yongjoo Ahn <yongjoo1.ahn@samsung.com>
ext/nnstreamer/tensor_filter/tensor_filter_tensorflow_lite.cc

index 1568b81..180e99c 100644 (file)
@@ -656,21 +656,26 @@ TFLiteInterpreter::setTensorProp (
     const std::vector<int> &tensor_idx_list, GstTensorsInfo *tensorMeta)
 {
   tensorMeta->num_tensors = tensor_idx_list.size ();
+  if (tensorMeta->num_tensors > NNS_TENSOR_SIZE_LIMIT) {
+    ml_logi ("Create extra tensor info for the tflite model");
+    gst_tensors_info_extra_create (tensorMeta);
+  }
 
   for (unsigned int i = 0; i < tensorMeta->num_tensors; ++i) {
     int idx = tensor_idx_list[i];
+    GstTensorInfo *info = gst_tensors_info_get_nth_info (tensorMeta, i);
 
-    if (getTensorDim (idx, tensorMeta->info[i].dimension)) {
+    if (getTensorDim (idx, info->dimension)) {
       ml_loge ("failed to get the dimension of input tensors");
       return -1;
     }
-    tensorMeta->info[i].type = getTensorType (interpreter->tensor (idx)->type);
-    tensorMeta->info[i].name = g_strdup (interpreter->tensor (idx)->name);
+    info->type = getTensorType (interpreter->tensor (idx)->type);
+    info->name = g_strdup (interpreter->tensor (idx)->name);
 
 #if (DBG)
-    gchar *dim_str = gst_tensor_get_dimension_string (tensorMeta->info[i].dimension);
+    gchar *dim_str = gst_tensor_get_dimension_string (info->dimension);
     ml_logi ("tensorMeta[%d] >> name[%s], type[%d], dim[%s]", i,
-        tensorMeta->info[i].name, tensorMeta->info[i].type, dim_str);
+        info->name, info->type, dim_str);
     g_free (dim_str);
 #endif
   }
@@ -816,7 +821,8 @@ TFLiteInterpreter::cacheInOutTensorPtr ()
     tensor_idx = interpreter->inputs ()[i];
     tensor_ptr = interpreter->tensor (tensor_idx);
 
-    if (tensor_ptr->bytes != gst_tensor_info_get_size (&inputTensorMeta.info[i]))
+    GstTensorInfo *info = gst_tensors_info_get_nth_info (&inputTensorMeta, i);
+    if (tensor_ptr->bytes != gst_tensor_info_get_size (info))
       goto fail_exit;
 
     inputTensorPtr.push_back (tensor_ptr);
@@ -828,7 +834,8 @@ TFLiteInterpreter::cacheInOutTensorPtr ()
     tensor_idx = interpreter->outputs ()[i];
     tensor_ptr = interpreter->tensor (tensor_idx);
 
-    if (tensor_ptr->bytes != gst_tensor_info_get_size (&outputTensorMeta.info[i]))
+    GstTensorInfo *info = gst_tensors_info_get_nth_info (&outputTensorMeta, i);
+    if (tensor_ptr->bytes != gst_tensor_info_get_size (info))
       goto fail_exit;
 
     outputTensorPtr.push_back (tensor_ptr);