[Filter/TF] apply c_api at tensor_filter_tensorflow
authorHyoung Joo Ahn <hello.ahn@samsung.com>
Wed, 24 Jul 2019 08:25:59 +0000 (17:25 +0900)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Thu, 1 Aug 2019 05:16:52 +0000 (14:16 +0900)
change the codes of tensorflow filter to use c_api only

Signed-off-by: Hyoung Joo Ahn <hello.ahn@samsung.com>
ext/nnstreamer/tensor_filter/tensor_filter_tensorflow_core.cc
ext/nnstreamer/tensor_filter/tensor_filter_tensorflow_core.h

index a974c19..ea0f023 100644 (file)
@@ -34,7 +34,7 @@
 #define DBG FALSE
 #endif
 
-std::map <void*, Tensor> TFCore::outputTensorMap;
+std::map <void*, TF_Tensor*> TFCore::outputTensorMap;
 
 /**
  * @brief      TFCore creator
@@ -42,10 +42,9 @@ std::map <void*, Tensor> TFCore::outputTensorMap;
  * @note       the model of _model_path will be loaded simultaneously
  * @return     Nothing
  */
-TFCore::TFCore (const char * _model_path)
+TFCore::TFCore (const char *_model_path)
 {
   model_path = _model_path;
-  configured = false;
 
   gst_tensors_info_init (&inputTensorMeta);
   gst_tensors_info_init (&outputTensorMeta);
@@ -57,6 +56,22 @@ TFCore::TFCore (const char * _model_path)
  */
 TFCore::~TFCore ()
 {
+  TF_DeleteGraph (graph);
+
+  TF_Status* status = TF_NewStatus ();
+  TF_CloseSession (session, status);
+  if (TF_GetCode (status) != TF_OK) {
+    g_critical ("Error during session close!! - [Code: %d] %s",
+      TF_GetCode (status), TF_Message (status));
+  }
+
+  TF_DeleteSession (session, status);
+  if (TF_GetCode (status) != TF_OK) {
+    g_critical ("Error during session delete!! - [Code: %d] %s",
+      TF_GetCode (status), TF_Message (status));
+  }
+  TF_DeleteStatus (status);
+
   gst_tensors_info_free (&inputTensorMeta);
   gst_tensors_info_free (&outputTensorMeta);
 }
@@ -64,17 +79,32 @@ TFCore::~TFCore ()
 /**
  * @brief      initialize the object with tensorflow model
  * @return 0 if OK. non-zero if error.
+ *        -1 if the model is not loaded.
+ *        -2 if the initialization of input tensor is failed.
+ *        -3 if the initialization of output tensor is failed.
  */
 int
-TFCore::init (const GstTensorFilterProperties * prop,
-  const gboolean tf_mem_optmz)
+TFCore::init (const GstTensorFilterProperties * prop)
 {
+  if (loadModel ()) {
+    g_critical ("Failed to load model");
+    return -1;
+  }
+
+  if (validateTensor (&prop->input_meta, 1)) {
+    g_critical ("Failed to validate input tensor");
+    return -2;
+  }
+
+  if (validateTensor (&prop->output_meta, 0)) {
+    g_critical ("Failed to validate output tensor");
+    return -3;
+  }
+
   gst_tensors_info_copy (&inputTensorMeta, &prop->input_meta);
   gst_tensors_info_copy (&outputTensorMeta, &prop->output_meta);
 
-  mem_optmz = tf_mem_optmz;
-
-  return loadModel ();
+  return 0;
 }
 
 /**
@@ -88,14 +118,21 @@ TFCore::getModelPath ()
 }
 
 /**
+ * @brief      the definition of a deallocator method
+ */
+static void
+DeallocateBuffer (void* data, size_t t) {
+  std::free (data);
+}
+
+/**
  * @brief      load the tf model
  * @note       the model will be loaded
  * @return 0 if OK. non-zero if error.
- *        -1 if the modelfile is not valid(or not exist).
+ *        -1 if the pb file is not regular.
  *        -2 if the pb file is not loaded.
- *        -3 if the input properties are different with model.
- *        -4 if the Tensorflow session is not initialized.
- *        -5 if the Tensorflow session is not created.
+ *        -3 if importing graph is failed.
+ *        -4 if the Tensorflow session is not created.
  */
 int
 TFCore::loadModel ()
@@ -103,43 +140,57 @@ TFCore::loadModel ()
 #if (DBG)
   gint64 start_time = g_get_real_time ();
 #endif
+  gsize file_size;
+  gchar *content = nullptr;
+  GError *file_error = nullptr;
 
-  Status status;
-  GraphDef graph_def;
+  g_assert (model_path != nullptr);
 
   if (!g_file_test (model_path, G_FILE_TEST_IS_REGULAR)) {
     g_critical ("the file of model_path (%s) is not valid (not regular)\n", model_path);
     return -1;
   }
-  status = ReadBinaryProto (Env::Default (), model_path, &graph_def);
-  if (!status.ok()) {
-    g_critical ("Failed to read graph.\n%s", status.ToString().c_str());
+
+  if (!g_file_get_contents (model_path, &content, &file_size, &file_error)) {
+    g_critical ("Error reading model file!! - %s", file_error->message);
+    g_clear_error (&file_error);
     return -2;
   }
 
-  /* validate input tensor */
-  if (validateInputTensor (graph_def)) {
-    g_critical ("Input Tensor Information is not valid");
-    return -3;
-  }
+  TF_Buffer* buffer = TF_NewBuffer ();
+  buffer->data = content;
+  buffer->length = file_size;
+  buffer->data_deallocator = DeallocateBuffer;
 
-  /* get session */
-  status = NewSession (SessionOptions (), &session);
-  if (!status.ok()) {
-    g_critical ("Failed to init new session.\n%s", status.ToString().c_str());
-    return -4;
-  }
+  graph = TF_NewGraph ();
+  TF_Status* status = TF_NewStatus ();
+  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions ();
+
+  TF_GraphImportGraphDef (graph, buffer, opts, status);
+  TF_DeleteImportGraphDefOptions (opts);
+  TF_DeleteBuffer (buffer);
 
-  status = session->Create (graph_def);
-  if (!status.ok()) {
-    g_critical ("Failed to create session.\n%s", status.ToString().c_str());
-    return -5;
+  if (TF_GetCode (status) != TF_OK) {
+    g_critical ("Error deleting graph!! - [Code: %d] %s",
+      TF_GetCode (status), TF_Message (status));
+    TF_DeleteStatus (status);
+    TF_DeleteGraph (graph);
+    return -3;
   }
 
-  /* prepare output tensor */
-  for (int i = 0; i < outputTensorMeta.num_tensors; ++i) {
-    output_tensor_names.push_back (outputTensorMeta.info[i].name);
+  g_assert (graph != nullptr);
+  TF_SessionOptions* options = TF_NewSessionOptions ();
+  session = TF_NewSession (graph, options, status);
+  TF_DeleteSessionOptions (options);
+
+  if (TF_GetCode (status) != TF_OK) {
+    g_critical ("Error creating Session!! - [Code: %d] %s",
+      TF_GetCode (status), TF_Message (status));
+    TF_DeleteStatus (status);
+    TF_DeleteGraph (graph);
+    return -4;
   }
+  TF_DeleteStatus (status);
 
 #if (DBG)
   gint64 stop_time = g_get_real_time ();
@@ -154,28 +205,28 @@ TFCore::loadModel ()
  * @return the enum of defined _NNS_TYPE
  */
 tensor_type
-TFCore::getTensorTypeFromTF (DataType tfType)
+TFCore::getTensorTypeFromTF (TF_DataType tfType)
 {
   switch (tfType) {
-    case DT_INT32:
+    case TF_INT32:
       return _NNS_INT32;
-    case DT_UINT32:
+    case TF_UINT32:
       return _NNS_UINT32;
-    case DT_INT16:
+    case TF_INT16:
       return _NNS_INT16;
-    case DT_UINT16:
+    case TF_UINT16:
       return _NNS_UINT16;
-    case DT_INT8:
+    case TF_INT8:
       return _NNS_INT8;
-    case DT_UINT8:
+    case TF_UINT8:
       return _NNS_UINT8;
-    case DT_INT64:
+    case TF_INT64:
       return _NNS_INT64;
-    case DT_UINT64:
+    case TF_UINT64:
       return _NNS_UINT64;
-    case DT_FLOAT:
+    case TF_FLOAT:
       return _NNS_FLOAT32;
-    case DT_DOUBLE:
+    case TF_DOUBLE:
       return _NNS_FLOAT64;
     default:
       /** @todo Support other types */
@@ -187,201 +238,114 @@ TFCore::getTensorTypeFromTF (DataType tfType)
 
 /**
  * @brief      return the data type of the tensor for Tensorflow
- * @param[in] tType    : the defined type of NNStreamer
- * @param[out] tf_type : the result type in TF_DataType
- * @return the result of type converting.
+ * @param tType        : the defined type of NNStreamer
+ * @return the enum of defined tensorflow::TF_DataType
  */
-gboolean
-TFCore::getTensorTypeToTF_Capi (tensor_type tType, TF_DataType * tf_type)
+TF_DataType
+TFCore::getTensorTypeToTF (tensor_type tType)
 {
   switch (tType) {
     case _NNS_INT32:
-      *tf_type = TF_INT32;
-      break;
+      return TF_INT32;
     case _NNS_UINT32:
-      *tf_type = TF_UINT32;
-      break;
+      return TF_UINT32;
     case _NNS_INT16:
-      *tf_type = TF_INT16;
-      break;
+      return TF_INT16;
     case _NNS_UINT16:
-      *tf_type = TF_UINT16;
-      break;
+      return TF_UINT16;
     case _NNS_INT8:
-      *tf_type = TF_INT8;
-      break;
+      return TF_INT8;
     case _NNS_UINT8:
-      *tf_type = TF_UINT8;
-      break;
+      return TF_UINT8;
     case _NNS_INT64:
-      *tf_type = TF_INT64;
-      break;
+      return TF_INT64;
     case _NNS_UINT64:
-      *tf_type = TF_UINT64;
-      break;
+      return TF_UINT64;
     case _NNS_FLOAT32:
-      *tf_type = TF_FLOAT;
-      break;
+      return TF_FLOAT;
     case _NNS_FLOAT64:
-      *tf_type = TF_DOUBLE;
-      break;
+      return TF_DOUBLE;
     default:
-      return FALSE;
+      /** @todo Support other types */
+      break;
   }
-  return TRUE;
+
+  return TF_VARIANT; // there is no flag for INVALID
 }
 
 /**
- * @brief      check the inserted information about input tensor with model
+ * @brief validate the src tensor info with graph
+ * @param      tensorInfo : the tensors' info which user inserted
+ * @param is_input : check is it input tensor or not to save the original shape
+ * @note Compare user inserted tensor information with information from loaded graphs
  * @return 0 if OK. non-zero if error.
- *        -1 if the number of input tensors is not matched.
- *        -2 if the name of input tensors is not matched.
- *        -3 if the type of input tensors is not matched.
- *        -4 if the dimension of input tensors is not matched.
- *        -5 if the rank of input tensors exceeds our capacity NNS_TENSOR_RANK_LIMIT.
+ *        -1 if getting rank of tensor is failed from the graph.
+ *        -2 if getting shape of tensor is failed from the graph.
  */
 int
-TFCore::validateInputTensor (const GraphDef &graph_def)
+TFCore::validateTensor (const GstTensorsInfo * tensorInfo, int is_input)
 {
-  std::vector <const NodeDef*> placeholders;
-  int length;
-
-  for (const NodeDef& node : graph_def.node ()) {
-    if (node.op () == "Placeholder") {
-      placeholders.push_back (&node);
-    }
-  }
-
-  if (placeholders.empty ()) {
-    GST_WARNING ("No inputs spotted.");
-    /* do nothing? */
-    return 0;
-  }
-
-  length = placeholders.size ();
-  GST_INFO ("Found possible inputs: %d", length);
-
-  if (inputTensorMeta.num_tensors != length) {
-    GST_ERROR ("Input Tensor is not valid: the number of input tensor is different\n");
-    return -1;
-  }
-
-  for (int i = 0; i < length; ++i) {
-    const NodeDef* node = placeholders[i];
-    string shape_description = "None";
-    if (node->attr ().count ("shape")) {
-      TensorShapeProto shape_proto = node->attr ().at ("shape").shape ();
-      Status shape_status = PartialTensorShape::IsValidShape (shape_proto);
-      if (shape_status.ok ()) {
-        shape_description = PartialTensorShape (shape_proto).DebugString ();
-      } else {
-        shape_description = shape_status.error_message ();
-      }
-    }
-    char chars[] = "[]";
-    for (unsigned int j = 0; j < strlen (chars); ++j)
-    {
-      shape_description.erase (
-        std::remove (
-          shape_description.begin (),
-          shape_description.end (),
-          chars[j]
-        ),
-        shape_description.end ()
-      );
+  for (int i = 0; i < tensorInfo->num_tensors; i++) {
+    // set the name of tensor
+    TF_Operation *op = TF_GraphOperationByName (graph, tensorInfo->info[i].name);
+
+    g_assert (op != nullptr);
+
+    const int num_outputs = TF_OperationNumOutputs (op);
+    g_assert (num_outputs == 1); /* an in/output tensor has only one output for now */
+
+    TF_Status *status = TF_NewStatus ();
+    const TF_Output output = {op, 0};
+    const TF_DataType type = TF_OperationOutputType (output);
+    const int num_dims = TF_GraphGetTensorNumDims (graph, output, status);
+    tf_tensor_info_s info_s;
+
+    if (TF_GetCode (status) != TF_OK) {
+      g_critical ("Error Tensor validation!! - [Code: %d] %s",
+        TF_GetCode (status), TF_Message (status));
+      TF_DeleteStatus (status);
+      return -1;
     }
 
-    DataType dtype = DT_INVALID;
-    char *tensor_name = inputTensorMeta.info[i].name;
-
-    if (node->attr ().count ("dtype")) {
-      dtype = node->attr ().at ("dtype").type ();
+    if (type != TF_STRING) {
+      g_assert (tensorInfo->info[i].type == getTensorTypeFromTF (type));
     }
+    info_s.type = type;
 
-    if (!tensor_name || !g_str_equal (tensor_name, node->name ().c_str ())) {
-      GST_ERROR ("Input Tensor is not valid: the name of input tensor is different\n");
-      return -2;
+    if (num_dims == -1) { /* in case of unknown shape */
+      info_s.rank = 0;
     }
-
-    if (inputTensorMeta.info[i].type != getTensorTypeFromTF (dtype)) {
-      /* consider the input data as bytes if tensor type is string */
-      if (dtype == DT_STRING) {
-        GST_WARNING ("Input type is string, ignore type comparision.");
-      } else {
-        GST_ERROR ("Input Tensor is not valid: the type of input tensor is different\n");
-        return -3;
+    else {
+      g_assert (num_dims > 0);
+      info_s.rank = num_dims;
+
+      std::vector<std::int64_t> dims (num_dims);
+
+      TF_GraphGetTensorShape (graph, output, dims.data (), num_dims, status);
+      if (TF_GetCode (status) != TF_OK) {
+        g_critical ("Error Tensor validation!! - [Code: %d] %s",
+          TF_GetCode (status), TF_Message (status));
+        TF_DeleteStatus (status);
+        return -2;
       }
-    }
-
-    gchar **str_dims;
-    guint rank, dim;
-    TensorShape ts = TensorShape ({});
-
-    str_dims = g_strsplit (shape_description.c_str (), ",", -1);
-    rank = g_strv_length (str_dims);
 
-    if (rank > NNS_TENSOR_RANK_LIMIT) {
-      GST_ERROR ("The Rank of Input Tensor is not affordable. It's over our capacity.\n");
-      g_strfreev (str_dims);
-      return -5;
-    }
-
-    for (int j = 0; j < rank; ++j) {
-      dim = inputTensorMeta.info[i].dimension[rank - j - 1];
-      ts.AddDim (dim);
-
-      if (g_str_equal (str_dims[j], "?"))
-        continue;
-
-      if (dim != (guint) g_ascii_strtoull (str_dims[j], NULL, 10)) {
-        GST_ERROR ("Input Tensor is not valid: the dim of input tensor is different\n");
-        g_strfreev (str_dims);
-        return -4;
+      // check the validity of dimension
+      for (int d = 0; d < num_dims; ++d) {
+        info_s.dims.push_back (
+          static_cast<int64_t> (tensorInfo->info[i].dimension[num_dims - d - 1])
+        );
+        if (dims[d] < 0) {
+          continue;
+        }
+        g_assert (tensorInfo->info[i].dimension[num_dims - d - 1] == dims[d]);
       }
     }
-    g_strfreev (str_dims);
-
-    /* add input tensor info */
-    tf_tensor_info_s info_s = { dtype, ts };
-    input_tensor_info.push_back (info_s);
-  }
-  return 0;
-}
-
-/**
- * @brief      check the inserted information about output tensor with model
- * @return 0 if OK. non-zero if error.
- *        -1 if the number of output tensors is not matched.
- *        -2 if the dimension of output tensors is not matched.
- *        -3 if the type of output tensors is not matched.
- */
-int
-TFCore::validateOutputTensor (const std::vector <Tensor> &outputs)
-{
-  if (outputTensorMeta.num_tensors != outputs.size()) {
-    GST_ERROR ("Invalid output meta: different size");
-    return -1;
-  }
-
-  for (int i = 0; i < outputTensorMeta.num_tensors; ++i) {
-    tensor_type otype;
-    gsize num;
-
-    otype = getTensorTypeFromTF (outputs[i].dtype());
-    num = gst_tensor_get_element_count (outputTensorMeta.info[i].dimension);
-
-    if (num != outputs[i].NumElements()) {
-      GST_ERROR ("Invalid output meta: different element count");
-      return -2;
-    }
-
-    if (outputTensorMeta.info[i].type != otype) {
-      GST_ERROR ("Invalid output meta: different type");
-      return -3;
+    if (is_input) {
+      /* save the original shape of the tensor */
+      input_tensor_info.push_back (info_s);
     }
+    TF_DeleteStatus (status);
   }
-
-  configured = true;
   return 0;
 }
 
@@ -412,41 +376,21 @@ TFCore::getOutputTensorDim (GstTensorsInfo * info)
 }
 
 /**
- * @brief      ring cache structure
+ * @brief      the definition of a deallocator method
  */
-class TFBuffer : public TensorBuffer {
- public:
-  void* data_;
-  size_t len_;
-
-#if (TF_MAJOR_VERSION == 1 && TF_MINOR_VERSION < 13)
-  explicit TFBuffer (void* data_ptr) : data_(data_ptr) {}
-  void* data () const override { return data_; }
-#elif (TF_MAJOR_VERSION == 1 && TF_MINOR_VERSION >= 13)
-  explicit TFBuffer (void* data_ptr) : TensorBuffer (data_ptr) {}
-#else
-#error This supports Tensorflow 1.x only.
-#endif
-
-  size_t size () const override { return len_; }
-  TensorBuffer* root_buffer () override { return this; }
-  void FillAllocationDescription (AllocationDescription* proto) const override {
-    int64 rb = size ();
-    proto->set_requested_bytes (rb);
-    proto->set_allocator_name (cpu_allocator ()->Name ());
-  }
-
-  /* Prevents input forwarding from mutating this buffer. */
-  bool OwnsMemory () const override { return false; }
-};
+static void
+DeallocateTensor (void* data, std::size_t, void*) {
+  /* do nothing, the data will be free at the last of pipeline */
+  return;
+}
 
 /**
  * @brief      run the model with the input.
  * @param[in] input : The array of input tensors
  * @param[out]  output : The array of output tensors
  * @return 0 if OK. non-zero if error.
- *         -1 if the model does not work properly.
- *         -2 if the output properties are different with model.
+ *        -1 if encoding STRING is failed.
+ *        -2 if running session is failed.
  */
 int
 TFCore::run (const GstTensorMemory * input, GstTensorMemory * output)
@@ -454,78 +398,109 @@ TFCore::run (const GstTensorMemory * input, GstTensorMemory * output)
 #if (DBG)
   gint64 start_time = g_get_real_time ();
 #endif
-
-  std::vector <std::pair <string, Tensor>> input_feeds;
-  std::vector <Tensor> outputs;
-
-  for (int i = 0; i < inputTensorMeta.num_tensors; ++i) {
-    Tensor in;
-
-    /* If the datatype is STRING, it should be handled in specific process */
-    if (input_tensor_info[i].type == DT_STRING) {
-      in = Tensor (input_tensor_info[i].type, input_tensor_info[i].shape);
-      in.scalar<string>()() = string ((char *) input[i].data, input[i].size);
-    } else {
-      if (mem_optmz) {
-        TFBuffer *buf;
-        TF_DataType dataType;
-
-        if (!getTensorTypeToTF_Capi (input[i].type, &dataType)){
-          g_critical ("This data type is not valid: %d", input[i].type);
-          return -1;
-        }
-
-        /* this input tensor should be UNREF */
-        buf = new TFBuffer (input[i].data);
-        buf->len_ = input[i].size;
-
-        in = TensorCApi::MakeTensor (
-          dataType,
-          input_tensor_info[i].shape,
-          buf
-        );
-
-        buf->Unref();
-
-        if (!in.IsAligned ()) {
-          g_critical ("the input tensor %s is not aligned", inputTensorMeta.info[i].name);
-          return -2;
-        }
-      } else {
-        in = Tensor (input_tensor_info[i].type, input_tensor_info[i].shape);
-        /* copy data */
-        std::copy_n ((char *) input[i].data, input[i].size,
-            const_cast<char *>(in.tensor_data().data()));
+  std::vector<TF_Output> input_ops;
+  std::vector<TF_Tensor*> input_tensors;
+  std::vector<TF_Output> output_ops;
+  std::vector<TF_Tensor*> output_tensors;
+  TF_Status* status = TF_NewStatus ();
+  char *input_encoded = nullptr;
+
+  // create input tensor for the graph from `input`
+  for (int i = 0; i < inputTensorMeta.num_tensors; i++) {
+    TF_Tensor* in_tensor = nullptr;
+    TF_Output input_op = {
+      TF_GraphOperationByName (graph, inputTensorMeta.info[i].name), 0
+      };
+    g_assert (input_op.oper != nullptr);
+    input_ops.push_back (input_op);
+
+    if (input_tensor_info[i].type == TF_STRING){
+      size_t encoded_size = TF_StringEncodedSize (input[i].size);
+      size_t total_size = 8 + encoded_size;
+      input_encoded = (char*) malloc (total_size);
+      for (int j =0; j < 8; ++j) {
+          input_encoded[j] = 0;
       }
+      TF_StringEncode (
+        (char *)input[i].data,
+        input[i].size,
+        input_encoded+8,
+        encoded_size,
+        status); // fills the rest of tensor data
+      if (TF_GetCode (status) != TF_OK) {
+        g_critical ("Error String Encoding!! - [Code: %d] %s",
+          TF_GetCode (status), TF_Message (status));
+        TF_DeleteStatus (status);
+        return -1;
+      }
+      in_tensor = TF_NewTensor (
+        input_tensor_info[i].type,
+        NULL,
+        0,
+        input_encoded,
+        total_size,
+        &DeallocateTensor,
+        nullptr);
+    }
+    else {
+      in_tensor = TF_NewTensor (
+          input_tensor_info[i].type,
+          input_tensor_info[i].dims.data (),
+          input_tensor_info[i].rank,
+          input[i].data,
+          input[i].size,
+          DeallocateTensor, /* no deallocator */
+          nullptr);
     }
-    input_feeds.push_back ({inputTensorMeta.info[i].name, in});
+    input_tensors.push_back (in_tensor);
   }
 
-  Status run_status =
-      session->Run (input_feeds, output_tensor_names, {}, &outputs);
+  // create output tensor for the graph from `output`
+  for (int i = 0; i < outputTensorMeta.num_tensors; i++) {
+    TF_Output output_op = {
+      TF_GraphOperationByName (graph, outputTensorMeta.info[i].name), 0
+      };
+    g_assert (output_op.oper != nullptr);
+    output_ops.push_back (output_op);
 
-  if (!run_status.ok()) {
-    g_critical ("Failed to run model: %s\n", run_status.ToString().c_str());
-    return -1;
+    TF_Tensor* out_tensor = nullptr;
+    output_tensors.push_back (out_tensor);
+  }
+
+  TF_SessionRun (session,
+                nullptr,
+                input_ops.data (), input_tensors.data (),
+                inputTensorMeta.num_tensors,
+                output_ops.data (), output_tensors.data (),
+                outputTensorMeta.num_tensors,
+                nullptr, 0,
+                nullptr,
+                status
+                );
+
+  for (int i = 0; i < inputTensorMeta.num_tensors; i++) {
+    TF_DeleteTensor (input_tensors[i]);
+    if (input_tensor_info[i].type == TF_STRING && input_encoded){
+      free (input_encoded);
+    }
   }
 
-  /* validate output tensor once */
-  if (!configured && validateOutputTensor (outputs)) {
-    g_critical ("Output Tensor Information is not valid");
+  if (TF_GetCode (status) != TF_OK) {
+    g_critical ("Error Running Session!! - [Code: %d] %s",
+      TF_GetCode (status), TF_Message (status));
+    TF_DeleteStatus (status);
     return -2;
   }
 
-  for (int i = 0; i < outputTensorMeta.num_tensors; ++i) {
-    /**
-     * @todo support DT_STRING output tensor
-     */
-    output[i].data = const_cast<char *>(outputs[i].tensor_data().data());
-    outputTensorMap.insert (std::make_pair (output[i].data, outputs[i]));
+  for (int i = 0; i < outputTensorMeta.num_tensors; i++) {
+    output[i].data = TF_TensorData (output_tensors[i]);
+    outputTensorMap.insert (std::make_pair (output[i].data, output_tensors[i]));
   }
+  TF_DeleteStatus (status);
 
 #if (DBG)
   gint64 stop_time = g_get_real_time ();
-  g_message ("Invoke() is finished: %" G_GINT64_FORMAT,
+  g_message ("Run() is finished: %" G_GINT64_FORMAT,
       (stop_time - start_time));
 #endif
 
@@ -533,7 +508,7 @@ TFCore::run (const GstTensorMemory * input, GstTensorMemory * output)
 }
 
 void *
-tf_core_new (const char * _model_path)
+tf_core_new (const char *_model_path)
 {
   return new TFCore (_model_path);
 }
@@ -560,7 +535,7 @@ tf_core_init (void * tf, const GstTensorFilterProperties * prop,
   const gboolean tf_mem_optmz)
 {
   TFCore *c = (TFCore *) tf;
-  return c->init (prop, tf_mem_optmz);
+  return c->init (prop);
 }
 
 /**
@@ -622,5 +597,6 @@ tf_core_run (void * tf, const GstTensorMemory * input, GstTensorMemory * output)
 void
 tf_core_destroyNotify (void * data)
 {
+  TF_DeleteTensor ( (TFCore::outputTensorMap.find (data))->second);
   TFCore::outputTensorMap.erase (data);
 }
index 1273b4a..16116d9 100644 (file)
 #include <fstream>
 #include <algorithm>
 #include <vector>
+#include <map>
 
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wredundant-decls"
 #include <tensorflow/c/c_api.h>
-#include <tensorflow/c/c_api_internal.h>
-#include <tensorflow/core/public/session.h>
-#include <tensorflow/core/public/version.h>
-#pragma GCC diagnostic pop
-
-using namespace tensorflow;
 
 /**
  * @brief      Internal data structure for tensorflow
  */
 typedef struct
 {
-  DataType type;
-  TensorShape shape;
+  TF_DataType type;
+  int rank;
+  std::vector < std::int64_t > dims;
 } tf_tensor_info_s;
 
 /**
@@ -68,33 +62,31 @@ public:
   TFCore (const char * _model_path);
    ~TFCore ();
 
-  int init (const GstTensorFilterProperties * prop, const gboolean tf_mem_optmz);
+  int init (const GstTensorFilterProperties * prop);
   int loadModel ();
-  const char* getModelPath();
+  const char *getModelPath ();
+
   int getInputTensorDim (GstTensorsInfo * info);
   int getOutputTensorDim (GstTensorsInfo * info);
   int run (const GstTensorMemory * input, GstTensorMemory * output);
 
-  static std::map <void*, Tensor> outputTensorMap;
+  static std::map < void *, TF_Tensor * >outputTensorMap;
 
 private:
 
   const char *model_path;
-  gboolean mem_optmz;
 
-  GstTensorsInfo inputTensorMeta;  /**< The tensor info of input tensors */
-  GstTensorsInfo outputTensorMeta;  /**< The tensor info of output tensors */
+  GstTensorsInfo inputTensorMeta;  /**< The tensor info of input tensors from user input */
+  GstTensorsInfo outputTensorMeta;  /**< The tensor info of output tensors from user input */
 
-  std::vector <tf_tensor_info_s> input_tensor_info;
-  std::vector <string> output_tensor_names;
-  bool configured; /**< True if the model is successfully loaded */
+  std::vector < tf_tensor_info_s > input_tensor_info; /* hold information for TF */
 
-  Session *session;
+  TF_Graph *graph;
+  TF_Session *session;
 
-  tensor_type getTensorTypeFromTF (DataType tfType);
-  gboolean getTensorTypeToTF_Capi (tensor_type tType, TF_DataType * tf_type);
-  int validateInputTensor (const GraphDef &graph_def);
-  int validateOutputTensor (const std::vector <Tensor> &outputs);
+  tensor_type getTensorTypeFromTF (TF_DataType tfType);
+  TF_DataType getTensorTypeToTF (tensor_type tType);
+  int validateTensor (const GstTensorsInfo * tensorInfo, int is_input);
 };
 
 /**
@@ -104,7 +96,7 @@ extern "C"
 {
 #endif
 
-  void *tf_core_new (const char *_model_path);
+  void *tf_core_new (const char * _model_path);
   void tf_core_delete (void * tf);
   int tf_core_init (void * tf, const GstTensorFilterProperties * prop,
       const gboolean tf_mem_optmz);
@@ -119,4 +111,4 @@ extern "C"
 }
 #endif
 
-#endif /* TENSOR_FILTER_TENSORFLOW_CORE_H */
+#endif                          /* TENSOR_FILTER_TENSORFLOW_CORE_H */