Support clang/clang++
[platform/upstream/nnstreamer.git] / tests / nnstreamer_filter_reload / tensor_filter_reload_test.c
1 /**
2  * SPDX-License-Identifier: LGPL-2.1-only
3  *
4  * @file    tensor_filter_reload_test.c
5  * @data    19 Dec 2019
6  * @brief   test case to test a filter's model reload
7  * @see     https://github.com/nnstreamer/nnstreamer
8  * @author  Dongju Chae <dongju.chae@samsung.com>
9  * @bug     No known bugs except NYI.
10  */
11
12 #include <string.h>
13 #include <stdlib.h>
14 #include <gst/gst.h>
15 #include <nnstreamer_util.h>
16 #include <unittest_util.h>
17
18 #define make_gst_element(element) do{\
19   element = gst_element_factory_make(#element, #element);\
20   if (!element) {\
21     g_printerr ("element %s could not be created.\n", #element);\
22     return_val = -1;\
23     goto out_unref;\
24   }\
25 } while (0);
26
27 #define IMAGE_FPS (25)
28 #define EVENT_INTERVAL (1000)
29 #define EVENT_TIMEOUT (10000)
30
31 static gchar *input_img_path = NULL;
32 static gchar *first_model_path = NULL;
33 static gchar *second_model_path = NULL;
34
35 static gint return_val = 0;
36
37 /**
38  * @brief Bus callback function
39  */
40 static gboolean
41 bus_callback (GstBus * bus, GstMessage * message, gpointer data)
42 {
43   GMainLoop *loop = data;
44   UNUSED (bus);
45   _print_log ("Got %s message\n", GST_MESSAGE_TYPE_NAME (message));
46
47   switch (GST_MESSAGE_TYPE (message)) {
48     case GST_MESSAGE_ERROR:{
49       GError *err;
50       gchar *debug;
51
52       gst_message_parse_error (message, &err, &debug);
53       _print_log ("Error: %s\n", err->message);
54       g_error_free (err);
55       g_free (debug);
56
57       g_main_loop_quit (loop);
58       break;
59     }
60     case GST_MESSAGE_EOS:
61       g_main_loop_quit (loop);
62       break;
63     default:
64       break;
65   }
66
67   return TRUE;
68 }
69
70 /**
71  * @brief Find the index with a maximum value
72  */
73 static gint
74 get_maximum_index (guint8 * data, gsize size)
75 {
76   gsize idx, max_idx = 0;
77   guint8 maximum = 0;
78
79   for (idx = 0; idx < size; ++idx) {
80     if (data[idx] > maximum) {
81       maximum = data[idx];
82       max_idx = idx;
83     }
84   }
85
86   return max_idx;
87 }
88
89 /**
90  * @brief Signal to handle new output data
91  */
92 static GstFlowReturn
93 check_output (GstElement * sink, void *data __attribute__ ((unused)))
94 {
95   static gint prev_index = -1;
96
97   GstSample *sample;
98   GstBuffer *buffer;
99   GstMapInfo info;
100   gint index;
101
102   g_signal_emit_by_name (sink, "pull-sample", &sample);
103   if (!sample)
104     return GST_FLOW_ERROR;
105
106   buffer = gst_sample_get_buffer (sample);
107   if (!buffer)
108     return GST_FLOW_ERROR;
109
110   if (!gst_buffer_map (buffer, &info, GST_MAP_READ))
111     return GST_FLOW_ERROR;
112
113   /**
114    * find the maximum entry; this value should be the same with
115    * the previous one even if a model is switched to the other one
116    */
117   index = get_maximum_index (info.data, info.size);
118   if (prev_index != -1 && prev_index != index) {
119     g_critical ("Output is different! %d vs %d\n", prev_index, index);
120     return_val = -1;
121   }
122
123   prev_index = index;
124
125   gst_buffer_unmap (buffer, &info);
126   gst_sample_unref (sample);
127
128   return GST_FLOW_OK;
129 }
130
131 /**
132  * @brief Reload a tensor filter's model (v1 <-> v2)
133  */
134 static gboolean
135 reload_model (GstElement * pipeline)
136 {
137   static gboolean is_first = TRUE;
138   const gchar *model_path;
139   GstElement *tensor_filter;
140
141   if (!pipeline)
142     return FALSE;
143
144   tensor_filter = gst_bin_get_by_name (GST_BIN (pipeline), "tfilter");
145   if (!tensor_filter) {
146     return FALSE;
147   }
148
149   model_path = is_first ? second_model_path : first_model_path;
150   setPipelineStateSync (pipeline, GST_STATE_PAUSED, UNITTEST_STATECHANGE_TIMEOUT);
151   g_usleep (TEST_DEFAULT_SLEEP_TIME);
152   g_object_set (G_OBJECT (tensor_filter), "model", model_path, NULL);
153
154   _print_log ("Model %s is just reloaded\n", model_path);
155
156   is_first = !is_first;
157   setPipelineStateSync (pipeline, GST_STATE_PLAYING, UNITTEST_STATECHANGE_TIMEOUT);
158   g_usleep (TEST_DEFAULT_SLEEP_TIME);
159   /* repeat if it's playing */
160
161   gst_object_unref (tensor_filter);
162
163   return (GST_STATE (GST_ELEMENT (pipeline)) == GST_STATE_PLAYING);
164 }
165
166 /**
167  * @brief Stop the main loop callback
168  */
169 static gboolean
170 stop_loop (GMainLoop * loop)
171 {
172   if (!loop)
173     return FALSE;
174
175   g_main_loop_quit (loop);
176
177   _print_log ("Now stop the loop\n");
178
179   /* stop */
180   return FALSE;
181 }
182
183 /**
184  * @brief Main function to evalute tensor_filter's model reload functionality
185  * @note feed the same input image to the tensor filter; So, even if a detection model
186  * is updated (mobilenet v1 <-> v2), the output should be the same for all frames.
187  */
188 int
189 main (int argc, char *argv[])
190 {
191   GMainLoop *loop;
192   GstElement *pipeline;
193   GstElement *filesrc, *pngdec, *videoscale, *imagefreeze, *videoconvert;
194   GstElement *capsfilter, *tensor_converter, *tensor_filter, *appsink;
195   GstCaps *caps;
196   GstBus *bus;
197
198   GError *error = NULL;
199   GOptionContext *opt_context;
200   const GOptionEntry opt_entries[] = {
201     {"input_img", 'i', G_OPTION_FLAG_NONE, G_OPTION_ARG_STRING, &input_img_path,
202           "The path of input image file",
203         "e.g., data/orange.png"},
204     {"first_model", 0, G_OPTION_FLAG_NONE, G_OPTION_ARG_STRING,
205           &first_model_path,
206           "The path of first model file",
207         "e.g., models/mobilenet_v1_1.0_224_quant.tflite"},
208     {"second_model", 0, G_OPTION_FLAG_NONE, G_OPTION_ARG_STRING,
209           &second_model_path,
210           "The path of second model file",
211         "e.g., models/mobilenet_v2_1.0_224_quant.tflite"},
212     {0}
213   };
214
215   /* parse options */
216   opt_context = g_option_context_new (NULL);
217   g_option_context_add_main_entries (opt_context, opt_entries, NULL);
218
219   if (!g_option_context_parse (opt_context, &argc, &argv, &error)) {
220     g_printerr ("Option parsing failed: %s\n", error->message);
221     g_error_free (error);
222     return -1;
223   }
224
225   g_option_context_free (opt_context);
226
227   if (!(input_img_path && first_model_path && second_model_path)) {
228     g_printerr ("No valid arguments provided\n");
229     return -1;
230   }
231
232   gst_init (&argc, &argv);
233   loop = g_main_loop_new (NULL, FALSE);
234
235   /* make pipeline & elements */
236   pipeline = gst_pipeline_new ("Pipeline with a model-updatable tensor filter");
237   make_gst_element (filesrc);
238   make_gst_element (pngdec);
239   make_gst_element (videoscale);
240   make_gst_element (imagefreeze);       /* feed the same input image */
241   make_gst_element (videoconvert);
242   make_gst_element (capsfilter);
243   make_gst_element (tensor_converter);
244   make_gst_element (tensor_filter);
245   make_gst_element (appsink);   /* output is verified in appsink callback */
246
247   /* set arguments of each element */
248   g_object_set (G_OBJECT (filesrc), "location", input_img_path, NULL);
249
250   caps = gst_caps_new_simple ("video/x-raw",
251       "format", G_TYPE_STRING, "RGB",
252       "framerate", GST_TYPE_FRACTION, IMAGE_FPS, 1, NULL);
253   g_object_set (G_OBJECT (capsfilter), "caps", caps, NULL);
254   gst_caps_unref (caps);
255
256   g_object_set (G_OBJECT (tensor_filter), "name", "tfilter", NULL);
257   g_object_set (G_OBJECT (tensor_filter), "framework", "tensorflow-lite", NULL);
258   g_object_set (G_OBJECT (tensor_filter), "model", first_model_path, NULL);
259   g_object_set (G_OBJECT (tensor_filter), "is-updatable", TRUE, NULL);
260
261   g_object_set (G_OBJECT (appsink), "emit-signals", TRUE, NULL);
262   g_object_set (G_OBJECT (appsink), "sync", FALSE, NULL);
263   g_signal_connect (appsink, "new-sample", G_CALLBACK (check_output), NULL);
264
265   /* link elements to the pipeline */
266   gst_bin_add_many (GST_BIN (pipeline), filesrc,
267       pngdec, videoscale, imagefreeze, videoconvert, capsfilter,
268       tensor_converter, tensor_filter, appsink, NULL);
269   gst_element_link_many (filesrc,
270       pngdec, videoscale, imagefreeze, videoconvert, capsfilter,
271       tensor_converter, tensor_filter, appsink, NULL);
272
273   bus = gst_pipeline_get_bus (GST_PIPELINE (pipeline));
274   gst_bus_add_watch (bus, bus_callback, loop);
275   gst_object_unref (bus);
276
277   gst_element_set_state (pipeline, GST_STATE_PLAYING);
278
279   /* add timeout events */
280   g_timeout_add (EVENT_INTERVAL, (GSourceFunc) reload_model, pipeline);
281   g_timeout_add (EVENT_TIMEOUT, (GSourceFunc) stop_loop, loop);
282
283   g_main_loop_run (loop);
284
285   gst_element_set_state (pipeline, GST_STATE_NULL);
286
287 out_unref:
288   gst_object_unref (GST_OBJECT (pipeline));
289   g_main_loop_unref (loop);
290
291   return return_val;
292 }