[Example] init tflite model
authorjy1210.jung <jy1210.jung@samsung.com>
Fri, 20 Jul 2018 05:08:56 +0000 (14:08 +0900)
committer함명주/동작제어Lab(SR)/Principal Engineer/삼성전자 <myungjoo.ham@samsung.com>
Mon, 23 Jul 2018 01:17:18 +0000 (10:17 +0900)
prepare example of tflite filter,
1. add code to check tflite model file and load labels
2. prepare to handle buffer passed to sink element
3. change pipeline for tflite model (224x224 frame)

**Self evaluation:**
1. Build test: [*]Passed [ ]Failed [ ]Skipped
2. Run test: [ ]Passed [ ]Failed [* ]Skipped

nnstreamer_example/example_filter/nnstreamer_example_filter.c

index 09c3ccb..86b8ad0 100644 (file)
  * Pipeline :
  * v4l2src -- tee -- textoverlay -- videoconvert -- xvimagesink
  *                  |
- *                  --- tensor_converter -- tensor_filter -- tensor_sink
+ *                  --- videoscale -- tensor_converter -- tensor_filter -- tensor_sink
  *
  * This app displays video sink (xvimagesink).
+ *
  * 'tensor_filter' for image recognition.
+ * Download tflite moel 'Mobilenet_1.0_224_quant' from below link,
+ * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md#image-classification-quantized-models
+ *
  * 'tensor_sink' updates recognition result to display in textoverlay.
  *
  * Run example :
  * $ ./nnstreamer_example_filter
  */
 
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+#include <glib.h>
 #include <gst/gst.h>
 
 /**
   }
 
 /**
+ * @brief Score threshold of tflite model.
+ */
+#define THRESHOLD (0.8)
+
+/**
+ * @brief Data structure for tflite model info.
+ */
+typedef struct
+{
+  gchar *model_path; /**< tflite model file path */
+  gchar *label_path; /**< label file path */
+  GList *labels; /**< list of loaded labels */
+} tflite_info_s;
+
+/**
  * @brief Data structure for app.
  */
 typedef struct
@@ -57,6 +84,9 @@ typedef struct
 
   gboolean running; /**< true when app is running */
   guint received; /**< received buffer count */
+  gint current_label; /**< current label index */
+  gint new_label; /**< new label index */
+  tflite_info_s tflite_info; /**< tflite model info */
 } AppData;
 
 /**
@@ -65,6 +95,128 @@ typedef struct
 static AppData g_app;
 
 /**
+ * @brief Free data in tflite info structure.
+ */
+static void
+_tflite_free_info (tflite_info_s * tflite_info)
+{
+  g_return_if_fail (tflite_info != NULL);
+
+  if (tflite_info->model_path) {
+    g_free (tflite_info->model_path);
+    tflite_info->model_path = NULL;
+  }
+
+  if (tflite_info->label_path) {
+    g_free (tflite_info->label_path);
+    tflite_info->label_path = NULL;
+  }
+
+  if (tflite_info->labels) {
+    g_list_free (tflite_info->labels);
+    tflite_info->labels = NULL;
+  }
+}
+
+/**
+ * @brief Check tflite model and load labels.
+ *
+ * This example uses 'Mobilenet_1.0_224_quant' for image classification.
+ */
+static gboolean
+_tflite_init_info (tflite_info_s * tflite_info, const gchar * path)
+{
+  const gchar tflite_model[] = "mobilenet_v1_1.0_224_quant.tflite";
+  const gchar tflite_label[] = "labels.txt";
+
+  FILE *fp;
+
+  g_return_val_if_fail (tflite_info != NULL, FALSE);
+
+  tflite_info->model_path = NULL;
+  tflite_info->label_path = NULL;
+  tflite_info->labels = NULL;
+
+  /** check model file exists */
+  tflite_info->model_path = g_strdup_printf ("%s/%s", path, tflite_model);
+
+  if (access (tflite_info->model_path, F_OK) != 0) {
+    _print_log ("cannot find tflite model [%s]", tflite_info->model_path);
+    return FALSE;
+  }
+
+  /** load labels */
+  tflite_info->label_path = g_strdup_printf ("%s/%s", path, tflite_label);
+
+  if ((fp = fopen (tflite_info->label_path, "r")) != NULL) {
+    char *line = NULL;
+    size_t len = 0;
+    ssize_t read;
+    gchar *label;
+
+    while ((read = getline (&line, &len, fp)) != -1) {
+      label = g_strdup ((gchar *) line);
+      tflite_info->labels = g_list_append (tflite_info->labels, label);
+    }
+
+    if (line) {
+      free (line);
+    }
+
+    fclose (fp);
+  } else {
+    _print_log ("cannot find tflite label [%s]", tflite_info->label_path);
+    return FALSE;
+  }
+
+  _print_log ("finished to load tflite label");
+  _print_log ("total labels %d", g_list_length (tflite_info->labels));
+  return TRUE;
+}
+
+/**
+ * @brief Get label string with given index.
+ */
+static gchar *
+_tflite_get_label (tflite_info_s * tflite_info, gint index)
+{
+  guint length;
+
+  g_return_val_if_fail (tflite_info != NULL, NULL);
+  g_return_val_if_fail (tflite_info->labels != NULL, NULL);
+
+  length = g_list_length (tflite_info->labels);
+  g_return_val_if_fail (index >= 0 && index < length, NULL);
+
+  return (gchar *) g_list_nth_data (tflite_info->labels, index);
+}
+
+/**
+ * @brief Get tflite label index.
+ * @param scores array of confidence score
+ * @param len array length
+ * @return -1 if failed to get max score index
+ */
+static gint
+_get_top_label_index (guint8 * scores, guint len)
+{
+  gint i;
+  gint index = -1;
+  guint8 max_score = 0;
+
+  g_return_val_if_fail (scores != NULL, -1);
+
+  for (i = 0; i < len; i++) {
+    if (scores[i] > 0 && scores[i] > max_score) {
+      index = i;
+      max_score = scores[i];
+    }
+  }
+
+  return index;
+}
+
+/**
  * @brief Free resources in app data.
  */
 static void
@@ -85,6 +237,8 @@ _free_app_data (void)
     gst_object_unref (g_app.pipeline);
     g_app.pipeline = NULL;
   }
+
+  _tflite_free_info (&g_app.tflite_info);
 }
 
 /**
@@ -154,13 +308,30 @@ _message_cb (GstBus * bus, GstMessage * message, gpointer user_data)
 static void
 _new_data_cb (GstElement * element, GstBuffer * buffer, gpointer user_data)
 {
+  /** print progress */
   g_app.received++;
   if (g_app.received % 150 == 0) {
     _print_log ("receiving new data [%d]", g_app.received);
   }
 
   if (g_app.running) {
-    /** @todo update textoverlay */
+    GstMemory *mem;
+    GstMapInfo info;
+    guint i;
+    guint num_mems;
+
+    num_mems = gst_buffer_n_memory (buffer);
+    for (i = 0; i < num_mems; i++) {
+      mem = gst_buffer_peek_memory (buffer, i);
+
+      if (gst_memory_map (mem, &info, GST_MAP_READ)) {
+        /** @todo handle data (info.data, info.size) */
+        _print_log ("received %zd", info.size);
+        g_app.new_label = _get_top_label_index (NULL, 0);
+
+        gst_memory_unmap (mem, &info);
+      }
+    }
   }
 }
 
@@ -199,17 +370,19 @@ static gboolean
 _timer_update_result_cb (gpointer user_data)
 {
   if (g_app.running) {
-    GstElement *textoverlay;
-    gchar *tensor_res;
+    GstElement *overlay;
+    gchar *label = NULL;
 
-    /** @todo update textoverlay */
-    tensor_res = g_strdup_printf ("total received %d", g_app.received);
+    if (g_app.current_label != g_app.new_label) {
+      g_app.current_label = g_app.new_label;
 
-    textoverlay = gst_bin_get_by_name (GST_BIN (g_app.pipeline), "tensor_res");
-    g_object_set (textoverlay, "text", tensor_res, NULL);
+      overlay = gst_bin_get_by_name (GST_BIN (g_app.pipeline), "tensor_res");
 
-    g_free (tensor_res);
-    gst_object_unref (textoverlay);
+      label = _tflite_get_label (&g_app.tflite_info, g_app.current_label);
+      g_object_set (overlay, "text", (label != NULL) ? label : "", NULL);
+
+      gst_object_unref (overlay);
+    }
   }
 
   return TRUE;
@@ -221,8 +394,9 @@ _timer_update_result_cb (gpointer user_data)
 int
 main (int argc, char **argv)
 {
-  const guint width = 640;
-  const guint height = 480;
+  const gchar tflite_model_path[] = "./tflite_model";
+  const guint width = 224;
+  const guint height = 224;
 
   gchar *str_pipeline;
   gulong handle_id;
@@ -234,6 +408,10 @@ main (int argc, char **argv)
   /** init app variable */
   g_app.running = FALSE;
   g_app.received = 0;
+  g_app.current_label = -1;
+  g_app.new_label = -1;
+
+  _check_cond_err (_tflite_init_info (&g_app.tflite_info, tflite_model_path));
 
   /** init gstreamer */
   gst_init (&argc, &argv);
@@ -243,15 +421,15 @@ main (int argc, char **argv)
   _check_cond_err (g_app.loop != NULL);
 
   /** init pipeline */
-  /** @todo add tensor filter */
   str_pipeline =
       g_strdup_printf
       ("v4l2src name=cam_src ! "
-      "video/x-raw,width=%d,height=%d,format=RGB,framerate=30/1 ! tee name=t_raw "
+      "video/x-raw,width=640,height=480,format=RGB,framerate=30/1 ! tee name=t_raw "
       "t_raw. ! queue ! textoverlay name=tensor_res font-desc=\"Sans, 24\" ! "
       "videoconvert ! xvimagesink name=img_tensor "
-      "t_raw. ! queue ! tensor_converter ! tensor_sink name=tensor_sink",
-      width, height);
+      "t_raw. ! queue ! videoscale ! video/x-raw,width=%d,height=%d ! tensor_converter ! "
+      "tensor_filter framework=tensorflow-lite model=%s ! tensor_sink name=tensor_sink",
+      width, height, g_app.tflite_info.model_path);
   g_app.pipeline = gst_parse_launch (str_pipeline, NULL);
   g_free (str_pipeline);
   _check_cond_err (g_app.pipeline != NULL);