[Example] update buffer data (filter example)
authorjy1210.jung <jy1210.jung@samsung.com>
Mon, 23 Jul 2018 07:47:08 +0000 (16:47 +0900)
committer함명주/동작제어Lab(SR)/Principal Engineer/삼성전자 <myungjoo.ham@samsung.com>
Tue, 24 Jul 2018 01:46:18 +0000 (10:46 +0900)
Add code to update label string from passed buffer.
1. update textoverlay with max score
2. update python example and document

Error to be fixed later: cannot link converter and filter.
To fix this issue, change tflite dim for caps negotication.

**Self evaluation:**
1. Build test: [*]Passed [ ]Failed [ ]Skipped
2. Run test: [ ]Passed [ ]Failed [*]Skipped

Signed-off-by: Jaeyun Jung <jy1210.jung@samsung.com>
Documentation/how-to-run-examples.md
nnstreamer_example/example_filter/nnstreamer_example_filter.c
nnstreamer_example/example_filter/nnstreamer_example_filter.py

index 521b0d8..d5faa2f 100644 (file)
@@ -32,26 +32,26 @@ $ cd ..
 ```
 v4l2src -- tee -- textoverlay -- videoconvert -- xvimagesink
             |
-            --- tensor_converter -- tensor_filter -- tensor_sink
+            --- videoscale -- tensor_converter -- tensor_filter -- tensor_sink
 ```
 
 NNStreamer example for image recognition.
 
 Displays video sink.
 
-1. 'tensor_filter' for image recognition.
+1. 'tensor_filter' for image recognition. (classification with 224x224 image).
 2. 'tensor_sink' updates recognition result to display in textoverlay.
 
 - Run example
 ```
-# python example
-$ cd nnstreamer_example/example_filter
-$ python nnstreamer_example_filter.py 
+$ cd build/nnstreamer_example/example_filter
+$ ./nnstreamer_example_filter 
 ```
 
 ```
-$ cd build/nnstreamer_example/example_filter
-$ ./nnstreamer_example_filter
+# for python example
+$ cd nnstreamer_example/example_filter
+$ python nnstreamer_example_filter.py 
 ```
 
 ## Example : video mixer
index 86b8ad0..adbbc6c 100644 (file)
   }
 
 /**
- * @brief Score threshold of tflite model.
- */
-#define THRESHOLD (0.8)
-
-/**
  * @brief Data structure for tflite model info.
  */
 typedef struct
@@ -71,6 +66,7 @@ typedef struct
   gchar *model_path; /**< tflite model file path */
   gchar *label_path; /**< label file path */
   GList *labels; /**< list of loaded labels */
+  guint total_labels; /**< count of labels */
 } tflite_info_s;
 
 /**
@@ -84,8 +80,8 @@ typedef struct
 
   gboolean running; /**< true when app is running */
   guint received; /**< received buffer count */
-  gint current_label; /**< current label index */
-  gint new_label; /**< new label index */
+  gint current_label_index; /**< current label index */
+  gint new_label_index; /**< new label index */
   tflite_info_s tflite_info; /**< tflite model info */
 } AppData;
 
@@ -169,8 +165,8 @@ _tflite_init_info (tflite_info_s * tflite_info, const gchar * path)
     return FALSE;
   }
 
-  _print_log ("finished to load tflite label");
-  _print_log ("total labels %d", g_list_length (tflite_info->labels));
+  tflite_info->total_labels = g_list_length (tflite_info->labels);
+  _print_log ("finished to load labels, total %d", tflite_info->total_labels);
   return TRUE;
 }
 
@@ -192,19 +188,23 @@ _tflite_get_label (tflite_info_s * tflite_info, gint index)
 }
 
 /**
- * @brief Get tflite label index.
- * @param scores array of confidence score
+ * @brief Update tflite label index with max score.
+ * @param scores array of scores
  * @param len array length
- * @return -1 if failed to get max score index
+ * @return None
  */
-static gint
-_get_top_label_index (guint8 * scores, guint len)
+static void
+_update_top_label_index (guint8 * scores, guint len)
 {
   gint i;
   gint index = -1;
   guint8 max_score = 0;
 
-  g_return_val_if_fail (scores != NULL, -1);
+  /** -1 if failed to get max score index */
+  g_app.new_label_index = -1;
+
+  g_return_if_fail (scores != NULL);
+  g_return_if_fail (len == g_app.tflite_info.total_labels);
 
   for (i = 0; i < len; i++) {
     if (scores[i] > 0 && scores[i] > max_score) {
@@ -213,7 +213,7 @@ _get_top_label_index (guint8 * scores, guint len)
     }
   }
 
-  return index;
+  g_app.new_label_index = index;
 }
 
 /**
@@ -325,9 +325,8 @@ _new_data_cb (GstElement * element, GstBuffer * buffer, gpointer user_data)
       mem = gst_buffer_peek_memory (buffer, i);
 
       if (gst_memory_map (mem, &info, GST_MAP_READ)) {
-        /** @todo handle data (info.data, info.size) */
-        _print_log ("received %zd", info.size);
-        g_app.new_label = _get_top_label_index (NULL, 0);
+        /** update label index with max score */
+        _update_top_label_index (info.data, (guint) info.size);
 
         gst_memory_unmap (mem, &info);
       }
@@ -373,12 +372,12 @@ _timer_update_result_cb (gpointer user_data)
     GstElement *overlay;
     gchar *label = NULL;
 
-    if (g_app.current_label != g_app.new_label) {
-      g_app.current_label = g_app.new_label;
+    if (g_app.current_label_index != g_app.new_label_index) {
+      g_app.current_label_index = g_app.new_label_index;
 
       overlay = gst_bin_get_by_name (GST_BIN (g_app.pipeline), "tensor_res");
 
-      label = _tflite_get_label (&g_app.tflite_info, g_app.current_label);
+      label = _tflite_get_label (&g_app.tflite_info, g_app.current_label_index);
       g_object_set (overlay, "text", (label != NULL) ? label : "", NULL);
 
       gst_object_unref (overlay);
@@ -395,6 +394,7 @@ int
 main (int argc, char **argv)
 {
   const gchar tflite_model_path[] = "./tflite_model";
+  /** 224x224 for tflite model */
   const guint width = 224;
   const guint height = 224;
 
@@ -408,8 +408,8 @@ main (int argc, char **argv)
   /** init app variable */
   g_app.running = FALSE;
   g_app.received = 0;
-  g_app.current_label = -1;
-  g_app.new_label = -1;
+  g_app.current_label_index = -1;
+  g_app.new_label_index = -1;
 
   _check_cond_err (_tflite_init_info (&g_app.tflite_info, tflite_model_path));
 
@@ -428,7 +428,8 @@ main (int argc, char **argv)
       "t_raw. ! queue ! textoverlay name=tensor_res font-desc=\"Sans, 24\" ! "
       "videoconvert ! xvimagesink name=img_tensor "
       "t_raw. ! queue ! videoscale ! video/x-raw,width=%d,height=%d ! tensor_converter ! "
-      "tensor_filter framework=tensorflow-lite model=%s ! tensor_sink name=tensor_sink",
+      "tensor_filter framework=tensorflow-lite model=%s ! "
+      "tensor_sink name=tensor_sink",
       width, height, g_app.tflite_info.model_path);
   g_app.pipeline = gst_parse_launch (str_pipeline, NULL);
   g_free (str_pipeline);
index 2bc1c8b..a3957f7 100644 (file)
@@ -13,10 +13,14 @@ NNStreamer example for image recognition.
 Pipeline :
 v4l2src -- tee -- textoverlay -- videoconvert -- xvimagesink
             |
-            --- tensor_converter -- tensor_filter -- tensor_sink
+            --- videoscale -- tensor_converter -- tensor_filter -- tensor_sink
 
 This app displays video sink (xvimagesink).
+
 'tensor_filter' for image recognition.
+Download tflite moel 'Mobilenet_1.0_224_quant' from below link,
+https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md#image-classification-quantized-models
+
 'tensor_sink' updates recognition result to display in textoverlay.
 
 Run example :
@@ -27,6 +31,8 @@ $ python nnstreamer_example_filter.py
 See https://lazka.github.io/pgi-docs/#Gst-1.0 for Gst API details.
 """
 
+import os
+import sys
 import gi
 gi.require_version('Gst', '1.0')
 from gi.repository import Gst, GObject
@@ -35,14 +41,21 @@ from gi.repository import Gst, GObject
 class NNStreamerExample:
     """NNStreamer example for image recognition."""
 
-    def __init__(self):
+    def __init__(self, argv=None):
         self.loop = None
         self.pipeline = None
         self.running = False
         self.received = 0
+        self.current_label_index = -1
+        self.new_label_index = -1
+        self.tflite_model = ''
+        self.tflite_labels = []
+
+        if not self.tflite_init():
+            raise Exception
 
         GObject.threads_init()
-        Gst.init(None)
+        Gst.init(argv)
 
     def run_example(self):
         """Init pipeline and run example.
@@ -53,23 +66,24 @@ class NNStreamerExample:
         self.loop = GObject.MainLoop()
 
         # init pipeline
-        # TODO: add tensor filter
         self.pipeline = Gst.parse_launch(
-            "v4l2src name=cam_src ! "
-            "video/x-raw,width=640,height=480,format=RGB,framerate=30/1 ! tee name=t_raw "
-            "t_raw. ! queue ! textoverlay name=tensor_res font-desc=\"Sans, 24\" ! "
-            "videoconvert ! xvimagesink name=img_tensor "
-            "t_raw. ! queue ! tensor_converter ! tensor_sink name=tensor_sink"
+            'v4l2src name=cam_src ! '
+            'video/x-raw,width=640,height=480,format=RGB,framerate=30/1 ! tee name=t_raw '
+            't_raw. ! queue ! textoverlay name=tensor_res font-desc=\"Sans, 24\" ! '
+            'videoconvert ! xvimagesink name=img_tensor '
+            't_raw. ! queue ! videoscale ! video/x-raw,width=224,height=224 ! tensor_converter ! '
+            f'tensor_filter framework=tensorflow-lite model={self.tflite_model} ! '
+            'tensor_sink name=tensor_sink'
         )
 
         # bus and message callback
         bus = self.pipeline.get_bus()
         bus.add_signal_watch()
-        bus.connect("message", self.on_bus_message)
+        bus.connect('message', self.on_bus_message)
 
         # tensor sink signal : new data callback
-        tensor_sink = self.pipeline.get_by_name("tensor_sink")
-        tensor_sink.connect("new-data", self.on_new_data)
+        tensor_sink = self.pipeline.get_by_name('tensor_sink')
+        tensor_sink.connect('new-data', self.on_new_data)
 
         # timer to update result
         GObject.timeout_add(500, self.on_timer_update_result)
@@ -79,7 +93,7 @@ class NNStreamerExample:
         self.running = True
 
         # set window title
-        self.set_window_title("img_tensor", "NNStreamer Example")
+        self.set_window_title('img_tensor', 'NNStreamer Example')
 
         # run main loop
         self.loop.run()
@@ -98,17 +112,17 @@ class NNStreamerExample:
         :return: None
         """
         if message.type == Gst.MessageType.EOS:
-            print("received eos message")
+            print('received eos message')
             self.loop.quit()
         elif message.type == Gst.MessageType.ERROR:
             error, debug = message.parse_error()
-            print(f"error {error} {debug}")
+            print(f'error {error} {debug}')
             self.loop.quit()
         elif message.type == Gst.MessageType.WARNING:
             error, debug = message.parse_warning()
-            print(f"warning {error} {debug}")
+            print(f'warning {error} {debug}')
         elif message.type == Gst.MessageType.STREAM_START:
-            print("received start message")
+            print('received start message')
 
     def on_new_data(self, sink, buffer):
         """Callback for tensor sink signal.
@@ -120,15 +134,15 @@ class NNStreamerExample:
         # print progress
         self.received += 1
         if (self.received % 150) == 0:
-            print(f"receiving new data [{self.received}]")
+            print(f'receiving new data [{self.received}]')
 
         if self.running:
-            # TODO: update textoverlay
             for idx in range(buffer.n_memory()):
                 mem = buffer.peek_memory(idx)
                 result, mapinfo = mem.map(Gst.MapFlags.READ)
                 if result:
-                    # print(f"received {mapinfo.size}")
+                    # update label index with max score
+                    self.update_top_label_index(mapinfo.data, mapinfo.size)
                     mem.unmap(mapinfo)
 
     def on_timer_update_result(self):
@@ -137,11 +151,12 @@ class NNStreamerExample:
         :return: True to ensure the timer continues
         """
         if self.running:
-            # TODO: update textoverlay
-            tensor_res = f"total received {self.received}"
-
-            textoverlay = self.pipeline.get_by_name("tensor_res")
-            textoverlay.set_property("text", tensor_res)
+            if self.current_label_index != self.new_label_index:
+                # update textoverlay
+                self.current_label_index = self.new_label_index
+                label = self.tflite_get_label(self.current_label_index)
+                textoverlay = self.pipeline.get_by_name('tensor_res')
+                textoverlay.set_property('text', label)
         return True
 
     def set_window_title(self, name, title):
@@ -153,13 +168,72 @@ class NNStreamerExample:
         """
         element = self.pipeline.get_by_name(name)
         if element is not None:
-            pad = element.get_static_pad("sink")
+            pad = element.get_static_pad('sink')
             if pad is not None:
                 tags = Gst.TagList.new_empty()
-                tags.add_value(Gst.TagMergeMode.APPEND, "title", title)
+                tags.add_value(Gst.TagMergeMode.APPEND, 'title', title)
                 pad.send_event(Gst.Event.new_tag(tags))
 
+    def tflite_init(self):
+        """Check tflite model and load labels.
+
+        :return: True if successfully initialized
+        """
+        tflite_model = 'mobilenet_v1_1.0_224_quant.tflite'
+        tflite_label = 'labels.txt'
+        current_folder = os.path.dirname(os.path.abspath(__file__))
+        model_folder = os.path.join(current_folder, 'tflite_model')
+
+        # check model file exists
+        self.tflite_model = os.path.join(model_folder, tflite_model)
+        if not os.path.exists(self.tflite_model):
+            print(f'cannot find tflite model [{self.tflite_model}]')
+            return False
+
+        # load labels
+        label_path = os.path.join(model_folder, tflite_label)
+        try:
+            with open(label_path, 'r') as label_file:
+                for line in label_file.readlines():
+                    self.tflite_labels.append(line)
+        except FileNotFoundError:
+            print(f'cannot find tflite label [{label_path}]')
+            return False
+
+        print(f'finished to load labels, total {len(self.tflite_labels)}')
+        return True
+
+    def tflite_get_label(self, index):
+        """Get label string with given index.
+
+        :param index: index for label
+        :return: label string
+        """
+        try:
+            label = self.tflite_labels[index]
+        except IndexError:
+            label = ''
+        return label
+
+    def update_top_label_index(self, data, data_size):
+        """Update tflite label index with max score.
+
+        :param data: array of scores
+        :param data_size: data size
+        :return: None
+        """
+        # -1 if failed to get max score index
+        self.new_label_index = -1
+
+        if data_size == len(self.tflite_labels):
+            scores = [data[i] for i in range(data_size)]
+            max_score = max(scores)
+            if max_score > 0:
+                self.new_label_index = scores.index(max_score)
+        else:
+            print(f'unexpected data size {data_size}')
+
 
-if __name__ == "__main__":
-    example = NNStreamerExample()
+if __name__ == '__main__':
+    example = NNStreamerExample(sys.argv[1:])
     example.run_example()