}
/**
- * @brief Score threshold of tflite model.
- */
-#define THRESHOLD (0.8)
-
-/**
* @brief Data structure for tflite model info.
*/
typedef struct
gchar *model_path; /**< tflite model file path */
gchar *label_path; /**< label file path */
GList *labels; /**< list of loaded labels */
+ guint total_labels; /**< count of labels */
} tflite_info_s;
/**
gboolean running; /**< true when app is running */
guint received; /**< received buffer count */
- gint current_label; /**< current label index */
- gint new_label; /**< new label index */
+ gint current_label_index; /**< current label index */
+ gint new_label_index; /**< new label index */
tflite_info_s tflite_info; /**< tflite model info */
} AppData;
return FALSE;
}
- _print_log ("finished to load tflite label");
- _print_log ("total labels %d", g_list_length (tflite_info->labels));
+ tflite_info->total_labels = g_list_length (tflite_info->labels);
+ _print_log ("finished to load labels, total %d", tflite_info->total_labels);
return TRUE;
}
}
/**
- * @brief Get tflite label index.
- * @param scores array of confidence score
+ * @brief Update tflite label index with max score.
+ * @param scores array of scores
* @param len array length
- * @return -1 if failed to get max score index
+ * @return None
*/
-static gint
-_get_top_label_index (guint8 * scores, guint len)
+static void
+_update_top_label_index (guint8 * scores, guint len)
{
gint i;
gint index = -1;
guint8 max_score = 0;
- g_return_val_if_fail (scores != NULL, -1);
+ /** -1 if failed to get max score index */
+ g_app.new_label_index = -1;
+
+ g_return_if_fail (scores != NULL);
+ g_return_if_fail (len == g_app.tflite_info.total_labels);
for (i = 0; i < len; i++) {
if (scores[i] > 0 && scores[i] > max_score) {
}
}
- return index;
+ g_app.new_label_index = index;
}
/**
mem = gst_buffer_peek_memory (buffer, i);
if (gst_memory_map (mem, &info, GST_MAP_READ)) {
- /** @todo handle data (info.data, info.size) */
- _print_log ("received %zd", info.size);
- g_app.new_label = _get_top_label_index (NULL, 0);
+ /** update label index with max score */
+ _update_top_label_index (info.data, (guint) info.size);
gst_memory_unmap (mem, &info);
}
GstElement *overlay;
gchar *label = NULL;
- if (g_app.current_label != g_app.new_label) {
- g_app.current_label = g_app.new_label;
+ if (g_app.current_label_index != g_app.new_label_index) {
+ g_app.current_label_index = g_app.new_label_index;
overlay = gst_bin_get_by_name (GST_BIN (g_app.pipeline), "tensor_res");
- label = _tflite_get_label (&g_app.tflite_info, g_app.current_label);
+ label = _tflite_get_label (&g_app.tflite_info, g_app.current_label_index);
g_object_set (overlay, "text", (label != NULL) ? label : "", NULL);
gst_object_unref (overlay);
main (int argc, char **argv)
{
const gchar tflite_model_path[] = "./tflite_model";
+ /** 224x224 for tflite model */
const guint width = 224;
const guint height = 224;
/** init app variable */
g_app.running = FALSE;
g_app.received = 0;
- g_app.current_label = -1;
- g_app.new_label = -1;
+ g_app.current_label_index = -1;
+ g_app.new_label_index = -1;
_check_cond_err (_tflite_init_info (&g_app.tflite_info, tflite_model_path));
"t_raw. ! queue ! textoverlay name=tensor_res font-desc=\"Sans, 24\" ! "
"videoconvert ! xvimagesink name=img_tensor "
"t_raw. ! queue ! videoscale ! video/x-raw,width=%d,height=%d ! tensor_converter ! "
- "tensor_filter framework=tensorflow-lite model=%s ! tensor_sink name=tensor_sink",
+ "tensor_filter framework=tensorflow-lite model=%s ! "
+ "tensor_sink name=tensor_sink",
width, height, g_app.tflite_info.model_path);
g_app.pipeline = gst_parse_launch (str_pipeline, NULL);
g_free (str_pipeline);
Pipeline :
v4l2src -- tee -- textoverlay -- videoconvert -- xvimagesink
|
- --- tensor_converter -- tensor_filter -- tensor_sink
+ --- videoscale -- tensor_converter -- tensor_filter -- tensor_sink
This app displays video sink (xvimagesink).
+
'tensor_filter' for image recognition.
+Download tflite moel 'Mobilenet_1.0_224_quant' from below link,
+https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md#image-classification-quantized-models
+
'tensor_sink' updates recognition result to display in textoverlay.
Run example :
See https://lazka.github.io/pgi-docs/#Gst-1.0 for Gst API details.
"""
+import os
+import sys
import gi
gi.require_version('Gst', '1.0')
from gi.repository import Gst, GObject
class NNStreamerExample:
"""NNStreamer example for image recognition."""
- def __init__(self):
+ def __init__(self, argv=None):
self.loop = None
self.pipeline = None
self.running = False
self.received = 0
+ self.current_label_index = -1
+ self.new_label_index = -1
+ self.tflite_model = ''
+ self.tflite_labels = []
+
+ if not self.tflite_init():
+ raise Exception
GObject.threads_init()
- Gst.init(None)
+ Gst.init(argv)
def run_example(self):
"""Init pipeline and run example.
self.loop = GObject.MainLoop()
# init pipeline
- # TODO: add tensor filter
self.pipeline = Gst.parse_launch(
- "v4l2src name=cam_src ! "
- "video/x-raw,width=640,height=480,format=RGB,framerate=30/1 ! tee name=t_raw "
- "t_raw. ! queue ! textoverlay name=tensor_res font-desc=\"Sans, 24\" ! "
- "videoconvert ! xvimagesink name=img_tensor "
- "t_raw. ! queue ! tensor_converter ! tensor_sink name=tensor_sink"
+ 'v4l2src name=cam_src ! '
+ 'video/x-raw,width=640,height=480,format=RGB,framerate=30/1 ! tee name=t_raw '
+ 't_raw. ! queue ! textoverlay name=tensor_res font-desc=\"Sans, 24\" ! '
+ 'videoconvert ! xvimagesink name=img_tensor '
+ 't_raw. ! queue ! videoscale ! video/x-raw,width=224,height=224 ! tensor_converter ! '
+ f'tensor_filter framework=tensorflow-lite model={self.tflite_model} ! '
+ 'tensor_sink name=tensor_sink'
)
# bus and message callback
bus = self.pipeline.get_bus()
bus.add_signal_watch()
- bus.connect("message", self.on_bus_message)
+ bus.connect('message', self.on_bus_message)
# tensor sink signal : new data callback
- tensor_sink = self.pipeline.get_by_name("tensor_sink")
- tensor_sink.connect("new-data", self.on_new_data)
+ tensor_sink = self.pipeline.get_by_name('tensor_sink')
+ tensor_sink.connect('new-data', self.on_new_data)
# timer to update result
GObject.timeout_add(500, self.on_timer_update_result)
self.running = True
# set window title
- self.set_window_title("img_tensor", "NNStreamer Example")
+ self.set_window_title('img_tensor', 'NNStreamer Example')
# run main loop
self.loop.run()
:return: None
"""
if message.type == Gst.MessageType.EOS:
- print("received eos message")
+ print('received eos message')
self.loop.quit()
elif message.type == Gst.MessageType.ERROR:
error, debug = message.parse_error()
- print(f"error {error} {debug}")
+ print(f'error {error} {debug}')
self.loop.quit()
elif message.type == Gst.MessageType.WARNING:
error, debug = message.parse_warning()
- print(f"warning {error} {debug}")
+ print(f'warning {error} {debug}')
elif message.type == Gst.MessageType.STREAM_START:
- print("received start message")
+ print('received start message')
def on_new_data(self, sink, buffer):
"""Callback for tensor sink signal.
# print progress
self.received += 1
if (self.received % 150) == 0:
- print(f"receiving new data [{self.received}]")
+ print(f'receiving new data [{self.received}]')
if self.running:
- # TODO: update textoverlay
for idx in range(buffer.n_memory()):
mem = buffer.peek_memory(idx)
result, mapinfo = mem.map(Gst.MapFlags.READ)
if result:
- # print(f"received {mapinfo.size}")
+ # update label index with max score
+ self.update_top_label_index(mapinfo.data, mapinfo.size)
mem.unmap(mapinfo)
def on_timer_update_result(self):
:return: True to ensure the timer continues
"""
if self.running:
- # TODO: update textoverlay
- tensor_res = f"total received {self.received}"
-
- textoverlay = self.pipeline.get_by_name("tensor_res")
- textoverlay.set_property("text", tensor_res)
+ if self.current_label_index != self.new_label_index:
+ # update textoverlay
+ self.current_label_index = self.new_label_index
+ label = self.tflite_get_label(self.current_label_index)
+ textoverlay = self.pipeline.get_by_name('tensor_res')
+ textoverlay.set_property('text', label)
return True
def set_window_title(self, name, title):
"""
element = self.pipeline.get_by_name(name)
if element is not None:
- pad = element.get_static_pad("sink")
+ pad = element.get_static_pad('sink')
if pad is not None:
tags = Gst.TagList.new_empty()
- tags.add_value(Gst.TagMergeMode.APPEND, "title", title)
+ tags.add_value(Gst.TagMergeMode.APPEND, 'title', title)
pad.send_event(Gst.Event.new_tag(tags))
+ def tflite_init(self):
+ """Check tflite model and load labels.
+
+ :return: True if successfully initialized
+ """
+ tflite_model = 'mobilenet_v1_1.0_224_quant.tflite'
+ tflite_label = 'labels.txt'
+ current_folder = os.path.dirname(os.path.abspath(__file__))
+ model_folder = os.path.join(current_folder, 'tflite_model')
+
+ # check model file exists
+ self.tflite_model = os.path.join(model_folder, tflite_model)
+ if not os.path.exists(self.tflite_model):
+ print(f'cannot find tflite model [{self.tflite_model}]')
+ return False
+
+ # load labels
+ label_path = os.path.join(model_folder, tflite_label)
+ try:
+ with open(label_path, 'r') as label_file:
+ for line in label_file.readlines():
+ self.tflite_labels.append(line)
+ except FileNotFoundError:
+ print(f'cannot find tflite label [{label_path}]')
+ return False
+
+ print(f'finished to load labels, total {len(self.tflite_labels)}')
+ return True
+
+ def tflite_get_label(self, index):
+ """Get label string with given index.
+
+ :param index: index for label
+ :return: label string
+ """
+ try:
+ label = self.tflite_labels[index]
+ except IndexError:
+ label = ''
+ return label
+
+ def update_top_label_index(self, data, data_size):
+ """Update tflite label index with max score.
+
+ :param data: array of scores
+ :param data_size: data size
+ :return: None
+ """
+ # -1 if failed to get max score index
+ self.new_label_index = -1
+
+ if data_size == len(self.tflite_labels):
+ scores = [data[i] for i in range(data_size)]
+ max_score = max(scores)
+ if max_score > 0:
+ self.new_label_index = scores.index(max_score)
+ else:
+ print(f'unexpected data size {data_size}')
+
-if __name__ == "__main__":
- example = NNStreamerExample()
+if __name__ == '__main__':
+ example = NNStreamerExample(sys.argv[1:])
example.run_example()