[Filter/Tensorflow] fix bug of testcase accepted/tizen/unified/20190326.024547 submit/tizen/20190325.074521
authorHyoung Joo Ahn <hello.ahn@samsung.com>
Mon, 25 Mar 2019 05:09:20 +0000 (14:09 +0900)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Mon, 25 Mar 2019 06:50:27 +0000 (15:50 +0900)
Since the STRING input tensor should be processed differently, the logic was fixed.

Signed-off-by: Hyoung Joo Ahn <hello.ahn@samsung.com>
ext/nnstreamer/tensor_filter/tensor_filter_tensorflow_core.cc
ext/nnstreamer/tensor_filter/tensor_filter_tensorflow_core.h

index d697e93..feae0fa 100644 (file)
@@ -187,39 +187,48 @@ TFCore::getTensorTypeFromTF (DataType tfType)
 
 /**
  * @brief      return the data type of the tensor for Tensorflow
- * @param tType        : the defined type of NNStreamer
- * @return the enum of defined tensorflow::DataType
+ * @param[in] tType    : the defined type of NNStreamer
+ * @param[out] tf_type : the result type in TF_DataType
+ * @return the result of type converting.
  */
-TF_DataType
-TFCore::getTensorTypeToTF_Capi (tensor_type tType)
+gboolean
+TFCore::getTensorTypeToTF_Capi (tensor_type tType, TF_DataType * tf_type)
 {
   switch (tType) {
     case _NNS_INT32:
-      return TF_INT32;
+      *tf_type = TF_INT32;
+      break;
     case _NNS_UINT32:
-      return TF_UINT32;
+      *tf_type = TF_UINT32;
+      break;
     case _NNS_INT16:
-      return TF_INT16;
+      *tf_type = TF_INT16;
+      break;
     case _NNS_UINT16:
-      return TF_UINT16;
+      *tf_type = TF_UINT16;
+      break;
     case _NNS_INT8:
-      return TF_INT8;
+      *tf_type = TF_INT8;
+      break;
     case _NNS_UINT8:
-      return TF_UINT8;
+      *tf_type = TF_UINT8;
+      break;
     case _NNS_INT64:
-      return TF_INT64;
+      *tf_type = TF_INT64;
+      break;
     case _NNS_UINT64:
-      return TF_UINT64;
+      *tf_type = TF_UINT64;
+      break;
     case _NNS_FLOAT32:
-      return TF_FLOAT;
+      *tf_type = TF_FLOAT;
+      break;
     case _NNS_FLOAT64:
-      return TF_DOUBLE;
-    default:
-      /** @todo Support other types */
+      *tf_type = TF_DOUBLE;
       break;
+    default:
+      return FALSE;
   }
-  /* Since there is no INVALID, TF_RESOURCE is used to detect invalid datatype temporally */
-  return TF_RESOURCE;
+  return TRUE;
 }
 
 /**
@@ -443,43 +452,40 @@ TFCore::run (const GstTensorMemory * input, GstTensorMemory * output)
   Tensor in;
 
   for (int i = 0; i < inputTensorMeta.num_tensors; ++i) {
-
-    if (mem_optmz) {
-      TFBuffer *buf = new TFBuffer;
-      buf->len_ = input[i].size;
-      buf->data_ = input[i].data;
-
-      TF_DataType dataType = getTensorTypeToTF_Capi (input[i].type);
-
-      if (dataType == TF_RESOURCE){
-        g_critical ("This data type is not valid: %d", input[i].type);
-        buf->Unref();
-        return -1;
-      }
-
-      in = TensorCApi::MakeTensor (
-        dataType,
-        input_tensor_info[i].shape,
-        buf
-      );
-      if (!in.IsAligned ()) {
-        g_critical ("the input tensor %s is not aligned", inputTensorMeta.info[i].name);
-        buf->Unref();
-        return -2;
-      }
-    }
-    else {
+    /* If the datatype is STRING, it should be handled in specific process */
+    if (input_tensor_info[i].type == DT_STRING) {
       in = Tensor (input_tensor_info[i].type, input_tensor_info[i].shape);
-
-      /* copy data */
-      if (input_tensor_info[i].type == DT_STRING) {
-        in.scalar<string>()() = string ((char *) input[i].data, input[i].size);
+      in.scalar<string>()() = string ((char *) input[i].data, input[i].size);
+    } else {
+      if (mem_optmz) {
+        TFBuffer *buf = new TFBuffer;
+        buf->len_ = input[i].size;
+        buf->data_ = input[i].data;
+
+        TF_DataType dataType;
+        if (!getTensorTypeToTF_Capi (input[i].type, &dataType)){
+          g_critical ("This data type is not valid: %d", input[i].type);
+          buf->Unref();
+          return -1;
+        }
+        /* this input tensor should be UNREF */
+        in = TensorCApi::MakeTensor (
+          dataType,
+          input_tensor_info[i].shape,
+          buf
+        );
+        if (!in.IsAligned ()) {
+          g_critical ("the input tensor %s is not aligned", inputTensorMeta.info[i].name);
+          buf->Unref();
+          return -2;
+        }
       } else {
+        in = Tensor (input_tensor_info[i].type, input_tensor_info[i].shape);
+        /* copy data */
         std::copy_n ((char *) input[i].data, input[i].size,
             const_cast<char *>(in.tensor_data().data()));
       }
     }
-
     input_feeds.push_back ({inputTensorMeta.info[i].name, in});
   }
 
@@ -487,8 +493,12 @@ TFCore::run (const GstTensorMemory * input, GstTensorMemory * output)
       session->Run (input_feeds, output_tensor_names, {}, &outputs);
 
   if (mem_optmz) {
-    TensorBuffer *buf = TensorCApi::Buffer (in);
-    buf->Unref();
+    for (int i = 0; i < inputTensorMeta.num_tensors; ++i) {
+      if (input_feeds[i].second.dtype () != DT_STRING){
+        TensorBuffer *buf = TensorCApi::Buffer (input_feeds[i].second);
+        buf->Unref();
+      }
+    }
   }
 
   if (!run_status.ok()) {
index ce08017..eebc38d 100644 (file)
@@ -90,7 +90,7 @@ private:
   Session *session;
 
   tensor_type getTensorTypeFromTF (DataType tfType);
-  TF_DataType getTensorTypeToTF_Capi (tensor_type tType);
+  gboolean getTensorTypeToTF_Capi (tensor_type tType, TF_DataType * tf_type);
   int validateInputTensor (const GraphDef &graph_def);
   int validateOutputTensor (const std::vector <Tensor> &outputs);
 };