/**
 * SPDX-License-Identifier: LGPL-2.1-only
 *
 * @file tensor_filter_reload_test.c
 * @brief Test case for reloading a tensor filter's model
 * @see https://github.com/nnstreamer/nnstreamer
 * @author Dongju Chae <dongju.chae@samsung.com>
 * @bug No known bugs except NYI.
 */
15 #include <nnstreamer_util.h>
16 #include <unittest_util.h>
/**
 * @brief Create a GStreamer element whose factory name and instance name are
 *        both the name of the given variable, e.g. make_gst_element (filesrc)
 *        creates a "filesrc" element and stores it in @a filesrc.
 * @note On failure this prints an error, sets return_val to -1 and jumps to
 *       the out_unref label, so it can only be used inside a function that
 *       defines that label. It expands to a single statement (no trailing
 *       semicolon), so it is safe in if/else bodies.
 */
#define make_gst_element(element) do { \
  (element) = gst_element_factory_make (#element, #element); \
  if (!(element)) { \
    g_printerr ("element %s could not be created.\n", #element); \
    return_val = -1; \
    goto out_unref; \
  } \
} while (0)
27 #define IMAGE_FPS (25)
28 #define EVENT_INTERVAL (1000)
29 #define EVENT_TIMEOUT (10000)
31 static gchar *input_img_path = NULL;
32 static gchar *first_model_path = NULL;
33 static gchar *second_model_path = NULL;
35 static gint return_val = 0;
38 * @brief Bus callback function
41 bus_callback (GstBus * bus, GstMessage * message, gpointer data)
43 GMainLoop *loop = data;
45 _print_log ("Got %s message\n", GST_MESSAGE_TYPE_NAME (message));
47 switch (GST_MESSAGE_TYPE (message)) {
48 case GST_MESSAGE_ERROR:{
52 gst_message_parse_error (message, &err, &debug);
53 _print_log ("Error: %s\n", err->message);
57 g_main_loop_quit (loop);
61 g_main_loop_quit (loop);
71 * @brief Find the index with a maximum value
74 get_maximum_index (guint8 * data, gsize size)
76 gsize idx, max_idx = 0;
79 for (idx = 0; idx < size; ++idx) {
80 if (data[idx] > maximum) {
90 * @brief Signal to handle new output data
93 check_output (GstElement * sink, void *data __attribute__ ((unused)))
95 static gint prev_index = -1;
102 g_signal_emit_by_name (sink, "pull-sample", &sample);
104 return GST_FLOW_ERROR;
106 buffer = gst_sample_get_buffer (sample);
108 return GST_FLOW_ERROR;
110 if (!gst_buffer_map (buffer, &info, GST_MAP_READ))
111 return GST_FLOW_ERROR;
114 * find the maximum entry; this value should be the same with
115 * the previous one even if a model is switched to the other one
117 index = get_maximum_index (info.data, info.size);
118 if (prev_index != -1 && prev_index != index) {
119 g_critical ("Output is different! %d vs %d\n", prev_index, index);
125 gst_buffer_unmap (buffer, &info);
126 gst_sample_unref (sample);
132 * @brief Reload a tensor filter's model (v1 <-> v2)
135 reload_model (GstElement * pipeline)
137 static gboolean is_first = TRUE;
138 const gchar *model_path;
139 GstElement *tensor_filter;
144 tensor_filter = gst_bin_get_by_name (GST_BIN (pipeline), "tfilter");
145 if (!tensor_filter) {
149 model_path = is_first ? second_model_path : first_model_path;
150 setPipelineStateSync (pipeline, GST_STATE_PAUSED, UNITTEST_STATECHANGE_TIMEOUT);
151 g_usleep (TEST_DEFAULT_SLEEP_TIME);
152 g_object_set (G_OBJECT (tensor_filter), "model", model_path, NULL);
154 _print_log ("Model %s is just reloaded\n", model_path);
156 is_first = !is_first;
157 setPipelineStateSync (pipeline, GST_STATE_PLAYING, UNITTEST_STATECHANGE_TIMEOUT);
158 g_usleep (TEST_DEFAULT_SLEEP_TIME);
159 /* repeat if it's playing */
161 gst_object_unref (tensor_filter);
163 return (GST_STATE (GST_ELEMENT (pipeline)) == GST_STATE_PLAYING);
167 * @brief Stop the main loop callback
170 stop_loop (GMainLoop * loop)
175 g_main_loop_quit (loop);
177 _print_log ("Now stop the loop\n");
184 * @brief Main function to evalute tensor_filter's model reload functionality
185 * @note feed the same input image to the tensor filter; So, even if a detection model
186 * is updated (mobilenet v1 <-> v2), the output should be the same for all frames.
189 main (int argc, char *argv[])
192 GstElement *pipeline;
193 GstElement *filesrc, *pngdec, *videoscale, *imagefreeze, *videoconvert;
194 GstElement *capsfilter, *tensor_converter, *tensor_filter, *appsink;
198 GError *error = NULL;
199 GOptionContext *opt_context;
200 const GOptionEntry opt_entries[] = {
201 {"input_img", 'i', G_OPTION_FLAG_NONE, G_OPTION_ARG_STRING, &input_img_path,
202 "The path of input image file",
203 "e.g., data/orange.png"},
204 {"first_model", 0, G_OPTION_FLAG_NONE, G_OPTION_ARG_STRING,
206 "The path of first model file",
207 "e.g., models/mobilenet_v1_1.0_224_quant.tflite"},
208 {"second_model", 0, G_OPTION_FLAG_NONE, G_OPTION_ARG_STRING,
210 "The path of second model file",
211 "e.g., models/mobilenet_v2_1.0_224_quant.tflite"},
216 opt_context = g_option_context_new (NULL);
217 g_option_context_add_main_entries (opt_context, opt_entries, NULL);
219 if (!g_option_context_parse (opt_context, &argc, &argv, &error)) {
220 g_printerr ("Option parsing failed: %s\n", error->message);
221 g_error_free (error);
225 g_option_context_free (opt_context);
227 if (!(input_img_path && first_model_path && second_model_path)) {
228 g_printerr ("No valid arguments provided\n");
232 gst_init (&argc, &argv);
233 loop = g_main_loop_new (NULL, FALSE);
235 /* make pipeline & elements */
236 pipeline = gst_pipeline_new ("Pipeline with a model-updatable tensor filter");
237 make_gst_element (filesrc);
238 make_gst_element (pngdec);
239 make_gst_element (videoscale);
240 make_gst_element (imagefreeze); /* feed the same input image */
241 make_gst_element (videoconvert);
242 make_gst_element (capsfilter);
243 make_gst_element (tensor_converter);
244 make_gst_element (tensor_filter);
245 make_gst_element (appsink); /* output is verified in appsink callback */
247 /* set arguments of each element */
248 g_object_set (G_OBJECT (filesrc), "location", input_img_path, NULL);
250 caps = gst_caps_new_simple ("video/x-raw",
251 "format", G_TYPE_STRING, "RGB",
252 "framerate", GST_TYPE_FRACTION, IMAGE_FPS, 1, NULL);
253 g_object_set (G_OBJECT (capsfilter), "caps", caps, NULL);
254 gst_caps_unref (caps);
256 g_object_set (G_OBJECT (tensor_filter), "name", "tfilter", NULL);
257 g_object_set (G_OBJECT (tensor_filter), "framework", "tensorflow-lite", NULL);
258 g_object_set (G_OBJECT (tensor_filter), "model", first_model_path, NULL);
259 g_object_set (G_OBJECT (tensor_filter), "is-updatable", TRUE, NULL);
261 g_object_set (G_OBJECT (appsink), "emit-signals", TRUE, NULL);
262 g_object_set (G_OBJECT (appsink), "sync", FALSE, NULL);
263 g_signal_connect (appsink, "new-sample", G_CALLBACK (check_output), NULL);
265 /* link elements to the pipeline */
266 gst_bin_add_many (GST_BIN (pipeline), filesrc,
267 pngdec, videoscale, imagefreeze, videoconvert, capsfilter,
268 tensor_converter, tensor_filter, appsink, NULL);
269 gst_element_link_many (filesrc,
270 pngdec, videoscale, imagefreeze, videoconvert, capsfilter,
271 tensor_converter, tensor_filter, appsink, NULL);
273 bus = gst_pipeline_get_bus (GST_PIPELINE (pipeline));
274 gst_bus_add_watch (bus, bus_callback, loop);
275 gst_object_unref (bus);
277 gst_element_set_state (pipeline, GST_STATE_PLAYING);
279 /* add timeout events */
280 g_timeout_add (EVENT_INTERVAL, (GSourceFunc) reload_model, pipeline);
281 g_timeout_add (EVENT_TIMEOUT, (GSourceFunc) stop_loop, loop);
283 g_main_loop_run (loop);
285 gst_element_set_state (pipeline, GST_STATE_NULL);
288 gst_object_unref (GST_OBJECT (pipeline));
289 g_main_loop_unref (loop);