1 /* SPDX-License-Identifier: LGPL-2.1-only */
3 * Copyright (C) 2021 Samsung Electronics Co., Ltd.
5 * @file tensor_query_client.c
7 * @brief GStreamer plugin to handle tensor query client
8 * @author Junhwan Kim <jejudo.kim@samsung.com>
9 * @see http://github.com/nnstreamer/nnstreamer
17 #include "nnstreamer_util.h"
18 #include "tensor_query_client.h"
22 #include "tensor_query_common.h"
29 * @brief Macro for debug mode.
32 #define DBG (!self->silent)
52 #define TCP_HIGHEST_PORT 65535
53 #define TCP_DEFAULT_HOST "localhost"
54 #define TCP_DEFAULT_SRV_SRC_PORT 3000
55 #define TCP_DEFAULT_CLIENT_SRC_PORT 3001
56 #define DEFAULT_CLIENT_TIMEOUT 0
57 #define DEFAULT_SILENT TRUE
58 #define DEFAULT_MAX_REQUEST 2
60 GST_DEBUG_CATEGORY_STATIC (gst_tensor_query_client_debug);
61 #define GST_CAT_DEFAULT gst_tensor_query_client_debug
64 * @brief the capabilities of the inputs.
66 static GstStaticPadTemplate sinktemplate = GST_STATIC_PAD_TEMPLATE ("sink",
72 * @brief the capabilities of the outputs.
74 static GstStaticPadTemplate srctemplate = GST_STATIC_PAD_TEMPLATE ("src",
79 #define gst_tensor_query_client_parent_class parent_class
80 G_DEFINE_TYPE (GstTensorQueryClient, gst_tensor_query_client, GST_TYPE_ELEMENT);
82 static void gst_tensor_query_client_finalize (GObject * object);
83 static void gst_tensor_query_client_set_property (GObject * object,
84 guint prop_id, const GValue * value, GParamSpec * pspec);
85 static void gst_tensor_query_client_get_property (GObject * object,
86 guint prop_id, GValue * value, GParamSpec * pspec);
88 static gboolean gst_tensor_query_client_sink_event (GstPad * pad,
89 GstObject * parent, GstEvent * event);
90 static gboolean gst_tensor_query_client_sink_query (GstPad * pad,
91 GstObject * parent, GstQuery * query);
92 static GstFlowReturn gst_tensor_query_client_chain (GstPad * pad,
93 GstObject * parent, GstBuffer * buf);
94 static GstCaps *gst_tensor_query_client_query_caps (GstTensorQueryClient * self,
95 GstPad * pad, GstCaps * filter);
98 * @brief initialize the class
101 gst_tensor_query_client_class_init (GstTensorQueryClientClass * klass)
103 GObjectClass *gobject_class;
104 GstElementClass *gstelement_class;
106 gobject_class = (GObjectClass *) klass;
107 gstelement_class = (GstElementClass *) klass;
109 gobject_class->set_property = gst_tensor_query_client_set_property;
110 gobject_class->get_property = gst_tensor_query_client_get_property;
111 gobject_class->finalize = gst_tensor_query_client_finalize;
113 /** install property goes here */
114 g_object_class_install_property (gobject_class, PROP_HOST,
115 g_param_spec_string ("host", "Host",
116 "A host address to receive the packets from query server",
117 TCP_DEFAULT_HOST, G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS));
118 g_object_class_install_property (gobject_class, PROP_PORT,
119 g_param_spec_uint ("port", "Port",
120 "A port number to receive the packets from query server", 0,
121 TCP_HIGHEST_PORT, TCP_DEFAULT_SRV_SRC_PORT,
122 G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS));
123 g_object_class_install_property (gobject_class, PROP_DEST_HOST,
124 g_param_spec_string ("dest-host", "Destination Host",
125 "A tenor query server host to send the packets",
126 TCP_DEFAULT_HOST, G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS));
127 g_object_class_install_property (gobject_class, PROP_DEST_PORT,
128 g_param_spec_uint ("dest-port", "Destination Port",
129 "The port of tensor query server to send the packets", 0,
130 TCP_HIGHEST_PORT, TCP_DEFAULT_CLIENT_SRC_PORT,
131 G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS));
132 g_object_class_install_property (gobject_class, PROP_SILENT,
133 g_param_spec_boolean ("silent", "Silent", "Produce verbose output",
134 DEFAULT_SILENT, G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS));
135 g_object_class_install_property (gobject_class, PROP_CONNECT_TYPE,
136 g_param_spec_enum ("connect-type", "Connect Type",
137 "The connections type between client and server.",
138 GST_TYPE_QUERY_CONNECT_TYPE, DEFAULT_CONNECT_TYPE,
139 G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS));
140 g_object_class_install_property (gobject_class, PROP_TOPIC,
141 g_param_spec_string ("topic", "Topic",
142 "The main topic of the host.",
143 "", G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS));
145 g_object_class_install_property (gobject_class, PROP_TIMEOUT,
146 g_param_spec_uint ("timeout", "timeout value",
147 "A timeout value (in ms) to wait message from query server after sending buffer to server. 0 means no wait.",
148 0, G_MAXUINT, DEFAULT_CLIENT_TIMEOUT,
149 G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS));
150 g_object_class_install_property (gobject_class, PROP_MAX_REQUEST,
151 g_param_spec_uint ("max-request", "Maximum number of request",
152 "Sets the maximum number of buffers to request to the query server. "
153 "If the processing speed of query server is slower than the query client, the input buffer is dropped. "
154 "Two buffers are requested by default, and 0 means that all buffers are sent to query server without drop. ",
155 0, G_MAXUINT, DEFAULT_MAX_REQUEST,
156 G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS));
157 gst_element_class_add_pad_template (gstelement_class,
158 gst_static_pad_template_get (&sinktemplate));
159 gst_element_class_add_pad_template (gstelement_class,
160 gst_static_pad_template_get (&srctemplate));
162 gst_element_class_set_static_metadata (gstelement_class,
163 "TensorQueryClient", "Filter/Tensor/Query",
164 "Handle querying tensor data through the network",
165 "Samsung Electronics Co., Ltd.");
167 GST_DEBUG_CATEGORY_INIT (gst_tensor_query_client_debug, "tensor_query_client",
168 0, "Tensor Query Client");
172 * @brief initialize the new element
175 gst_tensor_query_client_init (GstTensorQueryClient * self)
177 /** setup sink pad */
178 self->sinkpad = gst_pad_new_from_static_template (&sinktemplate, "sink");
179 gst_element_add_pad (GST_ELEMENT (self), self->sinkpad);
180 gst_pad_set_event_function (self->sinkpad,
181 GST_DEBUG_FUNCPTR (gst_tensor_query_client_sink_event));
182 gst_pad_set_query_function (self->sinkpad,
183 GST_DEBUG_FUNCPTR (gst_tensor_query_client_sink_query));
184 gst_pad_set_chain_function (self->sinkpad,
185 GST_DEBUG_FUNCPTR (gst_tensor_query_client_chain));
188 self->srcpad = gst_pad_new_from_static_template (&srctemplate, "src");
189 gst_element_add_pad (GST_ELEMENT (self), self->srcpad);
191 /* init properties */
192 self->silent = DEFAULT_SILENT;
193 self->connect_type = DEFAULT_CONNECT_TYPE;
194 self->host = g_strdup (TCP_DEFAULT_HOST);
195 self->port = TCP_DEFAULT_CLIENT_SRC_PORT;
196 self->dest_host = g_strdup (TCP_DEFAULT_HOST);
197 self->dest_port = TCP_DEFAULT_SRV_SRC_PORT;
199 self->in_caps_str = NULL;
200 self->timeout = DEFAULT_CLIENT_TIMEOUT;
202 self->msg_queue = g_async_queue_new ();
203 self->max_request = DEFAULT_MAX_REQUEST;
204 self->requested_num = 0;
208 * @brief finalize the object
211 gst_tensor_query_client_finalize (GObject * object)
213 GstTensorQueryClient *self = GST_TENSOR_QUERY_CLIENT (object);
214 nns_edge_data_h data_h;
218 g_free (self->dest_host);
219 self->dest_host = NULL;
220 g_free (self->topic);
222 g_free (self->in_caps_str);
223 self->in_caps_str = NULL;
225 while ((data_h = g_async_queue_try_pop (self->msg_queue))) {
226 nns_edge_data_destroy (data_h);
229 if (self->msg_queue) {
230 g_async_queue_unref (self->msg_queue);
231 self->msg_queue = NULL;
235 nns_edge_release_handle (self->edge_h);
239 G_OBJECT_CLASS (parent_class)->finalize (object);
243 * @brief set property
246 gst_tensor_query_client_set_property (GObject * object, guint prop_id,
247 const GValue * value, GParamSpec * pspec)
249 GstTensorQueryClient *self = GST_TENSOR_QUERY_CLIENT (object);
251 /** @todo DO NOT update properties (host, port, ..) while pipeline is running. */
254 if (!g_value_get_string (value)) {
255 nns_logw ("Sink host property cannot be NULL");
259 self->host = g_value_dup_string (value);
262 self->port = g_value_get_uint (value);
265 if (!g_value_get_string (value)) {
266 nns_logw ("Sink host property cannot be NULL");
269 g_free (self->dest_host);
270 self->dest_host = g_value_dup_string (value);
273 self->dest_port = g_value_get_uint (value);
275 case PROP_CONNECT_TYPE:
276 self->connect_type = g_value_get_enum (value);
279 if (!g_value_get_string (value)) {
280 nns_logw ("Topic property cannot be NULL. Query-hybrid is disabled.");
283 g_free (self->topic);
284 self->topic = g_value_dup_string (value);
287 self->timeout = g_value_get_uint (value);
290 self->silent = g_value_get_boolean (value);
292 case PROP_MAX_REQUEST:
293 self->max_request = g_value_get_uint (value);
296 G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
302 * @brief get property
305 gst_tensor_query_client_get_property (GObject * object, guint prop_id,
306 GValue * value, GParamSpec * pspec)
308 GstTensorQueryClient *self = GST_TENSOR_QUERY_CLIENT (object);
312 g_value_set_string (value, self->host);
315 g_value_set_uint (value, self->port);
318 g_value_set_string (value, self->dest_host);
321 g_value_set_uint (value, self->dest_port);
323 case PROP_CONNECT_TYPE:
324 g_value_set_enum (value, self->connect_type);
327 g_value_set_string (value, self->topic);
330 g_value_set_uint (value, self->timeout);
333 g_value_set_boolean (value, self->silent);
335 case PROP_MAX_REQUEST:
336 g_value_set_uint (value, self->max_request);
339 G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
345 * @brief Update src pad caps from tensors config.
348 gst_tensor_query_client_update_caps (GstTensorQueryClient * self,
349 const gchar * caps_str)
351 GstCaps *curr_caps, *out_caps;
352 gboolean ret = FALSE;
353 out_caps = gst_caps_from_string (caps_str);
354 silent_debug_caps (self, out_caps, "set out-caps");
356 /* Update src pad caps if it is different. */
357 curr_caps = gst_pad_get_current_caps (self->srcpad);
358 if (curr_caps == NULL || !gst_caps_is_equal (curr_caps, out_caps)) {
359 if (gst_caps_is_fixed (out_caps)) {
360 ret = gst_pad_set_caps (self->srcpad, out_caps);
362 nns_loge ("out-caps from tensor_query_serversink is not fixed. "
363 "Failed to update client src caps, out-caps: %s", caps_str);
366 /** Don't need to update when the capability is the same. */
371 gst_caps_unref (curr_caps);
373 gst_caps_unref (out_caps);
379 * @brief Parse caps from received event data.
382 _nns_edge_parse_caps (gchar * caps_str, gboolean is_src)
386 gchar *find_key = NULL;
387 gchar *ret_str = NULL;
392 strv = g_strsplit (caps_str, "@", -1);
393 num = g_strv_length (strv);
397 TRUE ? g_strdup ("query_server_src_caps") :
398 g_strdup ("query_server_sink_caps");
400 for (i = 1; i < num; i += 2) {
401 if (0 == g_strcmp0 (find_key, strv[i])) {
402 ret_str = g_strdup (strv[i + 1]);
414 * @brief nnstreamer-edge event callback.
417 _nns_edge_event_cb (nns_edge_event_h event_h, void *user_data)
419 nns_edge_event_e event_type;
420 int ret = NNS_EDGE_ERROR_NONE;
421 GstTensorQueryClient *self = (GstTensorQueryClient *) user_data;
423 if (NNS_EDGE_ERROR_NONE != nns_edge_event_get_type (event_h, &event_type)) {
424 nns_loge ("Failed to get event type!");
425 return NNS_EDGE_ERROR_NOT_SUPPORTED;
428 switch (event_type) {
429 case NNS_EDGE_EVENT_CAPABILITY:
431 GstCaps *server_caps, *client_caps;
432 GstStructure *server_st, *client_st;
433 gboolean result = FALSE;
434 gchar *ret_str, *caps_str;
436 nns_edge_event_parse_capability (event_h, &caps_str);
437 ret_str = _nns_edge_parse_caps (caps_str, TRUE);
438 nns_logd ("Received server-src caps: %s", GST_STR_NULL (ret_str));
439 client_caps = gst_caps_from_string ((gchar *) self->in_caps_str);
440 server_caps = gst_caps_from_string (ret_str);
443 /** Server framerate may vary. Let's skip comparing the framerate. */
444 gst_caps_set_simple (server_caps, "framerate", GST_TYPE_FRACTION, 0, 1,
446 gst_caps_set_simple (client_caps, "framerate", GST_TYPE_FRACTION, 0, 1,
449 server_st = gst_caps_get_structure (server_caps, 0);
450 client_st = gst_caps_get_structure (client_caps, 0);
452 if (gst_structure_is_tensor_stream (server_st)) {
453 GstTensorsConfig server_config, client_config;
455 gst_tensors_config_from_structure (&server_config, server_st);
456 gst_tensors_config_from_structure (&client_config, client_st);
458 result = gst_tensors_config_is_equal (&server_config, &client_config);
461 if (result || gst_caps_can_intersect (client_caps, server_caps)) {
462 /** Update client src caps */
463 ret_str = _nns_edge_parse_caps (caps_str, FALSE);
464 nns_logd ("Received server-sink caps: %s", GST_STR_NULL (ret_str));
465 if (!gst_tensor_query_client_update_caps (self, ret_str)) {
466 nns_loge ("Failed to update client source caps.");
467 ret = NNS_EDGE_ERROR_UNKNOWN;
471 /* respond deny with src caps string */
472 nns_loge ("Query caps is not acceptable!");
473 ret = NNS_EDGE_ERROR_UNKNOWN;
476 gst_caps_unref (server_caps);
477 gst_caps_unref (client_caps);
481 case NNS_EDGE_EVENT_NEW_DATA_RECEIVED:
483 nns_edge_data_h data;
485 nns_edge_event_parse_new_data (event_h, &data);
486 g_async_queue_push (self->msg_queue, data);
497 * @brief Internal function to create edge handle.
500 gst_tensor_query_client_create_edge_handle (GstTensorQueryClient * self)
502 gboolean started = FALSE;
503 gchar *prev_caps = NULL;
506 /* Already created, compare caps string. */
508 ret = nns_edge_get_info (self->edge_h, "CAPS", &prev_caps);
510 if (ret != NNS_EDGE_ERROR_NONE || !prev_caps ||
511 !g_str_equal (prev_caps, self->in_caps_str)) {
512 /* Capability is changed, close old handle. */
513 nns_edge_release_handle (self->edge_h);
520 ret = nns_edge_create_handle ("TEMP_ID", self->connect_type,
521 NNS_EDGE_NODE_TYPE_QUERY_CLIENT, &self->edge_h);
522 if (ret != NNS_EDGE_ERROR_NONE)
525 nns_edge_set_event_callback (self->edge_h, _nns_edge_event_cb, self);
528 nns_edge_set_info (self->edge_h, "TOPIC", self->topic);
530 nns_edge_set_info (self->edge_h, "HOST", self->host);
531 if (self->port > 0) {
532 gchar *port = g_strdup_printf ("%u", self->port);
533 nns_edge_set_info (self->edge_h, "PORT", port);
536 nns_edge_set_info (self->edge_h, "CAPS", self->in_caps_str);
538 ret = nns_edge_start (self->edge_h);
539 if (ret != NNS_EDGE_ERROR_NONE) {
541 ("Failed to start NNStreamer-edge. Please check server IP and port.");
545 ret = nns_edge_connect (self->edge_h, self->dest_host, self->dest_port);
546 if (ret != NNS_EDGE_ERROR_NONE) {
547 nns_loge ("Failed to connect to edge server!");
555 nns_edge_release_handle (self->edge_h);
563 * @brief This function handles sink event.
566 gst_tensor_query_client_sink_event (GstPad * pad,
567 GstObject * parent, GstEvent * event)
569 GstTensorQueryClient *self = GST_TENSOR_QUERY_CLIENT (parent);
571 GST_DEBUG_OBJECT (self, "Received %s event: %" GST_PTR_FORMAT,
572 GST_EVENT_TYPE_NAME (event), event);
574 switch (GST_EVENT_TYPE (event)) {
580 gst_event_parse_caps (event, &caps);
581 g_free (self->in_caps_str);
582 self->in_caps_str = gst_caps_to_string (caps);
584 ret = gst_tensor_query_client_create_edge_handle (self);
586 nns_loge ("Failed to create edge handle, cannot start query client.");
588 gst_event_unref (event);
595 return gst_pad_event_default (pad, parent, event);
599 * @brief This function handles sink pad query.
602 gst_tensor_query_client_sink_query (GstPad * pad,
603 GstObject * parent, GstQuery * query)
605 GstTensorQueryClient *self = GST_TENSOR_QUERY_CLIENT (parent);
607 GST_DEBUG_OBJECT (self, "Received %s query: %" GST_PTR_FORMAT,
608 GST_QUERY_TYPE_NAME (query), query);
610 switch (GST_QUERY_TYPE (query)) {
616 gst_query_parse_caps (query, &filter);
617 caps = gst_tensor_query_client_query_caps (self, pad, filter);
619 gst_query_set_caps_result (query, caps);
620 gst_caps_unref (caps);
623 case GST_QUERY_ACCEPT_CAPS:
626 GstCaps *template_caps;
627 gboolean res = FALSE;
629 gst_query_parse_accept_caps (query, &caps);
630 silent_debug_caps (self, caps, "accept-caps");
632 if (gst_caps_is_fixed (caps)) {
633 template_caps = gst_pad_get_pad_template_caps (pad);
635 res = gst_caps_can_intersect (template_caps, caps);
636 gst_caps_unref (template_caps);
639 gst_query_set_accept_caps_result (query, res);
646 return gst_pad_query_default (pad, parent, query);
650 * @brief Chain function, this function does the actual processing.
653 gst_tensor_query_client_chain (GstPad * pad,
654 GstObject * parent, GstBuffer * buf)
656 GstTensorQueryClient *self = GST_TENSOR_QUERY_CLIENT (parent);
657 GstBuffer *out_buf = NULL;
658 GstFlowReturn res = GST_FLOW_OK;
659 nns_edge_data_h data_h;
660 guint i, num_mems, num_data;
662 GstMemory *mem[NNS_TENSOR_SIZE_LIMIT];
663 GstMapInfo map[NNS_TENSOR_SIZE_LIMIT];
667 ret = nns_edge_data_create (&data_h);
668 if (ret != NNS_EDGE_ERROR_NONE) {
669 nns_loge ("Failed to create data handle in client chain.");
670 return GST_FLOW_ERROR;
673 num_mems = gst_buffer_n_memory (buf);
674 for (i = 0; i < num_mems; i++) {
675 mem[i] = gst_buffer_peek_memory (buf, i);
676 if (!gst_memory_map (mem[i], &map[i], GST_MAP_READ)) {
677 ml_loge ("Cannot map the %uth memory in gst-buffer.", i);
681 nns_edge_data_add (data_h, map[i].data, map[i].size, NULL);
684 nns_edge_get_info (self->edge_h, "client_id", &val);
685 nns_edge_data_set_info (data_h, "client_id", val);
688 if (self->requested_num > self->max_request) {
690 ("the processing speed of the query server is too slow. Drop the input buffer.");
692 if (NNS_EDGE_ERROR_NONE != nns_edge_send (self->edge_h, data_h)) {
693 nns_logi ("Failed to publish to server node.");
696 self->requested_num++;
699 nns_edge_data_destroy (data_h);
701 data_h = g_async_queue_timeout_pop (self->msg_queue,
702 self->timeout * G_TIME_SPAN_MILLISECOND);
704 self->requested_num--;
705 ret = nns_edge_data_get_count (data_h, &num_data);
706 if (ret != NNS_EDGE_ERROR_NONE || num_data == 0) {
707 nns_loge ("Failed to get the number of memories of the edge data.");
708 res = GST_FLOW_ERROR;
712 out_buf = gst_buffer_new ();
713 for (i = 0; i < num_data; i++) {
718 nns_edge_data_get (data_h, i, &data, &data_len);
719 new_data = _g_memdup (data, data_len);
720 gst_buffer_append_memory (out_buf,
721 gst_memory_new_wrapped (0, new_data, data_len, 0,
722 data_len, new_data, g_free));
724 /* metadata from incoming buffer */
725 gst_buffer_copy_into (out_buf, buf, GST_BUFFER_COPY_METADATA, 0, -1);
727 res = gst_pad_push (self->srcpad, out_buf);
732 nns_edge_data_destroy (data_h);
735 for (i = 0; i < num_mems; i++)
736 gst_memory_unmap (mem[i], &map[i]);
738 gst_buffer_unref (buf);
743 * @brief Get pad caps for caps negotiation.
746 gst_tensor_query_client_query_caps (GstTensorQueryClient * self, GstPad * pad,
751 caps = gst_pad_get_current_caps (pad);
753 /** pad don't have current caps. use the template caps */
754 caps = gst_pad_get_pad_template_caps (pad);
757 silent_debug_caps (self, caps, "caps");
758 silent_debug_caps (self, filter, "filter");
761 GstCaps *intersection;
763 gst_caps_intersect_full (filter, caps, GST_CAPS_INTERSECT_FIRST);
765 gst_caps_unref (caps);
769 silent_debug_caps (self, caps, "result");