* Pipeline :
* v4l2src -- tee -- textoverlay -- videoconvert -- xvimagesink
* |
- * --- tensor_converter -- tensor_filter -- tensor_sink
+ * --- videoscale -- tensor_converter -- tensor_filter -- tensor_sink
*
* This app displays video sink (xvimagesink).
+ *
* 'tensor_filter' for image recognition.
+ * Download tflite moel 'Mobilenet_1.0_224_quant' from below link,
+ * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md#image-classification-quantized-models
+ *
* 'tensor_sink' updates recognition result to display in textoverlay.
*
* Run example :
* $ ./nnstreamer_example_filter
*/
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+#include <glib.h>
#include <gst/gst.h>
/**
}
/**
+ * @brief Score threshold of tflite model.
+ */
+#define THRESHOLD (0.8)
+
+/**
+ * @brief Data structure for tflite model info.
+ */
+typedef struct
+{
+ gchar *model_path; /**< tflite model file path */
+ gchar *label_path; /**< label file path */
+ GList *labels; /**< list of loaded labels */
+} tflite_info_s;
+
+/**
* @brief Data structure for app.
*/
typedef struct
gboolean running; /**< true when app is running */
guint received; /**< received buffer count */
+ gint current_label; /**< current label index */
+ gint new_label; /**< new label index */
+ tflite_info_s tflite_info; /**< tflite model info */
} AppData;
/**
static AppData g_app;
/**
+ * @brief Free data in tflite info structure.
+ */
+static void
+_tflite_free_info (tflite_info_s * tflite_info)
+{
+ g_return_if_fail (tflite_info != NULL);
+
+ if (tflite_info->model_path) {
+ g_free (tflite_info->model_path);
+ tflite_info->model_path = NULL;
+ }
+
+ if (tflite_info->label_path) {
+ g_free (tflite_info->label_path);
+ tflite_info->label_path = NULL;
+ }
+
+ if (tflite_info->labels) {
+ g_list_free (tflite_info->labels);
+ tflite_info->labels = NULL;
+ }
+}
+
+/**
+ * @brief Check tflite model and load labels.
+ *
+ * This example uses 'Mobilenet_1.0_224_quant' for image classification.
+ */
+static gboolean
+_tflite_init_info (tflite_info_s * tflite_info, const gchar * path)
+{
+ const gchar tflite_model[] = "mobilenet_v1_1.0_224_quant.tflite";
+ const gchar tflite_label[] = "labels.txt";
+
+ FILE *fp;
+
+ g_return_val_if_fail (tflite_info != NULL, FALSE);
+
+ tflite_info->model_path = NULL;
+ tflite_info->label_path = NULL;
+ tflite_info->labels = NULL;
+
+ /** check model file exists */
+ tflite_info->model_path = g_strdup_printf ("%s/%s", path, tflite_model);
+
+ if (access (tflite_info->model_path, F_OK) != 0) {
+ _print_log ("cannot find tflite model [%s]", tflite_info->model_path);
+ return FALSE;
+ }
+
+ /** load labels */
+ tflite_info->label_path = g_strdup_printf ("%s/%s", path, tflite_label);
+
+ if ((fp = fopen (tflite_info->label_path, "r")) != NULL) {
+ char *line = NULL;
+ size_t len = 0;
+ ssize_t read;
+ gchar *label;
+
+ while ((read = getline (&line, &len, fp)) != -1) {
+ label = g_strdup ((gchar *) line);
+ tflite_info->labels = g_list_append (tflite_info->labels, label);
+ }
+
+ if (line) {
+ free (line);
+ }
+
+ fclose (fp);
+ } else {
+ _print_log ("cannot find tflite label [%s]", tflite_info->label_path);
+ return FALSE;
+ }
+
+ _print_log ("finished to load tflite label");
+ _print_log ("total labels %d", g_list_length (tflite_info->labels));
+ return TRUE;
+}
+
+/**
+ * @brief Get label string with given index.
+ */
+static gchar *
+_tflite_get_label (tflite_info_s * tflite_info, gint index)
+{
+ guint length;
+
+ g_return_val_if_fail (tflite_info != NULL, NULL);
+ g_return_val_if_fail (tflite_info->labels != NULL, NULL);
+
+ length = g_list_length (tflite_info->labels);
+ g_return_val_if_fail (index >= 0 && index < length, NULL);
+
+ return (gchar *) g_list_nth_data (tflite_info->labels, index);
+}
+
+/**
+ * @brief Get tflite label index.
+ * @param scores array of confidence score
+ * @param len array length
+ * @return -1 if failed to get max score index
+ */
+static gint
+_get_top_label_index (guint8 * scores, guint len)
+{
+ gint i;
+ gint index = -1;
+ guint8 max_score = 0;
+
+ g_return_val_if_fail (scores != NULL, -1);
+
+ for (i = 0; i < len; i++) {
+ if (scores[i] > 0 && scores[i] > max_score) {
+ index = i;
+ max_score = scores[i];
+ }
+ }
+
+ return index;
+}
+
+/**
* @brief Free resources in app data.
*/
static void
gst_object_unref (g_app.pipeline);
g_app.pipeline = NULL;
}
+
+ _tflite_free_info (&g_app.tflite_info);
}
/**
static void
_new_data_cb (GstElement * element, GstBuffer * buffer, gpointer user_data)
{
+ /** print progress */
g_app.received++;
if (g_app.received % 150 == 0) {
_print_log ("receiving new data [%d]", g_app.received);
}
if (g_app.running) {
- /** @todo update textoverlay */
+ GstMemory *mem;
+ GstMapInfo info;
+ guint i;
+ guint num_mems;
+
+ num_mems = gst_buffer_n_memory (buffer);
+ for (i = 0; i < num_mems; i++) {
+ mem = gst_buffer_peek_memory (buffer, i);
+
+ if (gst_memory_map (mem, &info, GST_MAP_READ)) {
+ /** @todo handle data (info.data, info.size) */
+ _print_log ("received %zd", info.size);
+ g_app.new_label = _get_top_label_index (NULL, 0);
+
+ gst_memory_unmap (mem, &info);
+ }
+ }
}
}
_timer_update_result_cb (gpointer user_data)
{
if (g_app.running) {
- GstElement *textoverlay;
- gchar *tensor_res;
+ GstElement *overlay;
+ gchar *label = NULL;
- /** @todo update textoverlay */
- tensor_res = g_strdup_printf ("total received %d", g_app.received);
+ if (g_app.current_label != g_app.new_label) {
+ g_app.current_label = g_app.new_label;
- textoverlay = gst_bin_get_by_name (GST_BIN (g_app.pipeline), "tensor_res");
- g_object_set (textoverlay, "text", tensor_res, NULL);
+ overlay = gst_bin_get_by_name (GST_BIN (g_app.pipeline), "tensor_res");
- g_free (tensor_res);
- gst_object_unref (textoverlay);
+ label = _tflite_get_label (&g_app.tflite_info, g_app.current_label);
+ g_object_set (overlay, "text", (label != NULL) ? label : "", NULL);
+
+ gst_object_unref (overlay);
+ }
}
return TRUE;
int
main (int argc, char **argv)
{
- const guint width = 640;
- const guint height = 480;
+ const gchar tflite_model_path[] = "./tflite_model";
+ const guint width = 224;
+ const guint height = 224;
gchar *str_pipeline;
gulong handle_id;
/** init app variable */
g_app.running = FALSE;
g_app.received = 0;
+ g_app.current_label = -1;
+ g_app.new_label = -1;
+
+ _check_cond_err (_tflite_init_info (&g_app.tflite_info, tflite_model_path));
/** init gstreamer */
gst_init (&argc, &argv);
_check_cond_err (g_app.loop != NULL);
/** init pipeline */
- /** @todo add tensor filter */
str_pipeline =
g_strdup_printf
("v4l2src name=cam_src ! "
- "video/x-raw,width=%d,height=%d,format=RGB,framerate=30/1 ! tee name=t_raw "
+ "video/x-raw,width=640,height=480,format=RGB,framerate=30/1 ! tee name=t_raw "
"t_raw. ! queue ! textoverlay name=tensor_res font-desc=\"Sans, 24\" ! "
"videoconvert ! xvimagesink name=img_tensor "
- "t_raw. ! queue ! tensor_converter ! tensor_sink name=tensor_sink",
- width, height);
+ "t_raw. ! queue ! videoscale ! video/x-raw,width=%d,height=%d ! tensor_converter ! "
+ "tensor_filter framework=tensorflow-lite model=%s ! tensor_sink name=tensor_sink",
+ width, height, g_app.tflite_info.model_path);
g_app.pipeline = gst_parse_launch (str_pipeline, NULL);
g_free (str_pipeline);
_check_cond_err (g_app.pipeline != NULL);