[Filter/Tensorflow] take the original rank of model internally
authorHyoungjooAhn <hello.ahnn@gmail.com>
Mon, 17 Dec 2018 10:26:01 +0000 (19:26 +0900)
committerjaeyun-jung <39614140+jaeyun-jung@users.noreply.github.com>
Wed, 19 Dec 2018 05:39:56 +0000 (14:39 +0900)
to run model, hold the original rank internally

Signed-off-by: HyoungjooAhn <hello.ahnn@gmail.com>
gst/tensor_filter/tensor_filter_tensorflow_core.cc
gst/tensor_filter/tensor_filter_tensorflow_core.h

index 79918ae..58e32d5 100644 (file)
@@ -276,16 +276,16 @@ TFCore::inputTensorValidation (std::vector<const NodeDef*> placeholders)
 
     gchar **str_dims;
     str_dims = g_strsplit (shape_description.c_str(), ",", -1);
-    uint len = g_strv_length (str_dims);
-    if (len > NNS_TENSOR_RANK_LIMIT){
+    inputTensorRank[i] = g_strv_length (str_dims);
+    if (inputTensorRank[i] > NNS_TENSOR_RANK_LIMIT){
       GST_ERROR ("The Rank of Input Tensor is not affordable. It's over our capacity.\n");
       return -5;
     }
-    for (int j = 0; j < len; j++) {
+    for (int j = 0; j < inputTensorRank[i]; j++) {
       if (!strcmp (str_dims[j], "?"))
         continue;
 
-      if (inputTensorMeta.info[i].dimension[len - j - 1] != atoi (str_dims[j])){
+      if (inputTensorMeta.info[i].dimension[inputTensorRank[i] - j - 1] != atoi (str_dims[j])){
         GST_ERROR ("Input Tensor is not valid: the dim of input tensor is different\n");
         return -4;
       }
@@ -363,7 +363,7 @@ TFCore::getOutputTensorDim (GstTensorsInfo * info)
 }
 
 #define copyInputWithType(type) \
-  inputTensor.flat<type>()(i) = ((type*)input->data)[i];
+  inputTensor.flat<type>()(j) = ((type*)input->data)[j];
 
 #define copyOutputWithType(type) \
   for(int j = 0; j < n; j++) \
@@ -378,62 +378,58 @@ TFCore::getOutputTensorDim (GstTensorsInfo * info)
 int
 TFCore::run (const GstTensorMemory * input, GstTensorMemory * output)
 {
-  /* TODO: Convert input -> inputTensor before run */
-
-  Tensor inputTensor(
-    getTensorTypeToTF(input->type),
-    TensorShape({
-      inputTensorMeta.info[0].dimension[3],
-      inputTensorMeta.info[0].dimension[2],
-      inputTensorMeta.info[0].dimension[1],
-      inputTensorMeta.info[0].dimension[0]
-    })
-  );
-  int len = input->size / tensor_element_size[input->type];
-
-  for (int i = 0; i < len; i++) {
-    switch (input->type) {
-      case _NNS_INT32:
-        copyInputWithType (int32);
-        break;
-      case _NNS_UINT32:
-        copyInputWithType (uint32);
-        break;
-      case _NNS_INT16:
-        copyInputWithType (int16);
-        break;
-      case _NNS_UINT16:
-        copyInputWithType (uint16);
-        break;
-      case _NNS_INT8:
-        copyInputWithType (int8);
-        break;
-      case _NNS_UINT8:
-        copyInputWithType (uint8);
-        break;
-      case _NNS_INT64:
-        copyInputWithType (int64);
-        break;
-      case _NNS_UINT64:
-        copyInputWithType (uint64);
-        break;
-      case _NNS_FLOAT32:
-        copyInputWithType (float);
-        break;
-      case _NNS_FLOAT64:
-        copyInputWithType (double);
-        break;
-      default:
-        /** @todo Support other types */
-        break;
-    }
-  }
-
   std::vector<std::pair<string, Tensor>> input_feeds;
   std::vector<string> output_tensor_names;
   std::vector<Tensor> outputs;
 
   for (int i = 0; i < inputTensorMeta.num_tensors; i++) {
+    TensorShape ts = TensorShape({});
+    for (int j = inputTensorRank[i] - 1; j >= 0; j--){
+      ts.AddDim(inputTensorMeta.info[i].dimension[j]);
+    }
+    Tensor inputTensor(
+      getTensorTypeToTF(input->type),
+      ts
+    );
+    int len = input->size / tensor_element_size[input->type];
+
+    for (int j = 0; j < len; j++) {
+      switch (input->type) {
+        case _NNS_INT32:
+          copyInputWithType (int32);
+          break;
+        case _NNS_UINT32:
+          copyInputWithType (uint32);
+          break;
+        case _NNS_INT16:
+          copyInputWithType (int16);
+          break;
+        case _NNS_UINT16:
+          copyInputWithType (uint16);
+          break;
+        case _NNS_INT8:
+          copyInputWithType (int8);
+          break;
+        case _NNS_UINT8:
+          copyInputWithType (uint8);
+          break;
+        case _NNS_INT64:
+          copyInputWithType (int64);
+          break;
+        case _NNS_UINT64:
+          copyInputWithType (uint64);
+          break;
+        case _NNS_FLOAT32:
+          copyInputWithType (float);
+          break;
+        case _NNS_FLOAT64:
+          copyInputWithType (double);
+          break;
+        default:
+          /** @todo Support other types */
+          break;
+      }
+    }
     input_feeds.push_back({inputTensorMeta.info[i].name, inputTensor});
   }
 
index 09e6393..342b04e 100644 (file)
@@ -83,6 +83,9 @@ private:
   GstTensorsInfo inputTensorMeta;  /**< The tensor info of input tensors */
   GstTensorsInfo outputTensorMeta;  /**< The tensor info of output tensors */
 
+  int inputTensorRank[NNS_TENSOR_SIZE_LIMIT];
+  int outputTensorRank[NNS_TENSOR_SIZE_LIMIT];
+
   Session * session;
 
   tensor_type getTensorTypeFromTF (DataType tfType);