[Filter/TF-Lite] apply memcpy for output tensors
authorHyoungjooAhn <hello.ahnn@gmail.com>
Tue, 16 Oct 2018 07:46:55 +0000 (16:46 +0900)
committerMyungJoo Ham <myungjoo.ham@gmail.com>
Wed, 17 Oct 2018 06:32:26 +0000 (15:32 +0900)
memcpy is applied for each output tensors of tflite. it should be just temporal solution and be updated with better way.

Signed-off-by: HyoungjooAhn <hello.ahnn@gmail.com>
gst/tensor_filter/tensor_filter_tensorflow_lite.c
gst/tensor_filter/tensor_filter_tensorflow_lite_core.cc

index 2aa6cf2..df82428 100644 (file)
@@ -150,7 +150,7 @@ tflite_close (const GstTensorFilter * filter, void **private_data)
 GstTensorFilterFramework NNS_support_tensorflow_lite = {
   .name = "tensorflow-lite",
   .allow_in_place = FALSE,      /** @todo: support this to optimize performance later. */
-  .allocate_in_invoke = TRUE,
+  .allocate_in_invoke = FALSE,
   .invoke_NN = tflite_invoke,
   .getInputDimension = tflite_getInputDim,
   .getOutputDimension = tflite_getOutputDim,
index fb756ac..1656dd7 100644 (file)
@@ -271,14 +271,6 @@ TFLiteCore::getOutputTensorDim (GstTensorsInfo * info)
 }
 
 /**
- * @breif A macro to reduce repeated switch-case lines
- */
-#define case4type(casename, type, element) \
-  case casename: \
-    inputTensor->data.element = (type *) input[i].data; \
-    break;
-
-/**
  * @brief      run the model with the input.
  * @param[in] input : The array of input tensors
  * @param[out]  output : The array of output tensors
@@ -298,27 +290,11 @@ TFLiteCore::invoke (const GstTensorMemory * input, GstTensorMemory * output)
   }
 
   for (int i = 0; i < getInputTensorSize (); i++) {
-    int in_tensor = interpreter->inputs ()[i];
-
-    TfLiteTensor *inputTensor = interpreter->tensor (in_tensor);
-
-    switch (inputTensorMeta.info[i].type) {
-      case4type (_NNS_FLOAT32, float, f);
-      case4type (_NNS_UINT8, uint8_t, uint8);
-      case4type (_NNS_INT32, int, i32);
-      case4type (_NNS_INT64, int64_t, i64);
-
-      case _NNS_UINT32:
-      case _NNS_INT16:
-      case _NNS_UINT16:
-      case _NNS_INT8:
-      case _NNS_FLOAT64:
-      case _NNS_UINT64:
-      default:
-        _print_log ("Not Supported Type");
-        return -3;
-        break;
-    }
+    int in_index = interpreter->inputs ()[i];
+    TfLiteTensor *input_tensor = interpreter->tensor (in_index);
+
+    g_assert (input_tensor->bytes == input[i].size);
+    input_tensor->data.raw = (char *) input[i].data;
   }
 
   if (interpreter->Invoke () != kTfLiteOk) {
@@ -327,11 +303,11 @@ TFLiteCore::invoke (const GstTensorMemory * input, GstTensorMemory * output)
   }
 
   for (int i = 0; i < outputTensorMeta.num_tensors; i++) {
-    if (outputTensorMeta.info[i].type == _NNS_FLOAT32) {
-      output[i].data = interpreter->typed_output_tensor < float >(i);
-    } else if (outputTensorMeta.info[i].type == _NNS_UINT8) {
-      output[i].data = interpreter->typed_output_tensor < uint8_t > (i);
-    }
+    int out_index = interpreter->outputs ()[i];
+    TfLiteTensor *output_tensor = interpreter->tensor (out_index);
+
+    g_assert (output_tensor->bytes == output[i].size);
+    memcpy (output[i].data, output_tensor->data.raw, output[i].size);
   }
 
 #if (DBG)