28f4cf2fa069d02ebd78822b90bc4184349c656e
[platform/upstream/gstreamer.git] / subprojects / gst-plugins-bad / ext / onnx / gstonnxobjectdetector.cpp
1 /*
2  * GStreamer gstreamer-onnxobjectdetector
3  * Copyright (C) 2021 Collabora Ltd.
4  *
5  * gstonnxobjectdetector.c
6  * 
7  * This library is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Library General Public
9  * License as published by the Free Software Foundation; either
10  * version 2 of the License, or (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Library General Public License for more details.
16  *
17  * You should have received a copy of the GNU Library General Public
18  * License along with this library; if not, write to the
19  * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
20  * Boston, MA 02110-1301, USA.
21  */
22
23 /**
24  * SECTION:element-onnxobjectdetector
25  * @short_description: Detect objects in video frame
26  *
27  * This element can apply a generic ONNX object detection model such as YOLO or SSD
28  * to each video frame.
29  *
30  * To install ONNX on your system, recursively clone this repository
31  * https://github.com/microsoft/onnxruntime.git
32  *
33  * and build and install with cmake:
34  *
35  * CPU:
36  *
37  *  cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF \
38  *  $SRC_DIR/onnxruntime/cmake && make -j8 && sudo make install
39  *
40  *
41  * GPU :
42  *
43  * cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF -Donnxruntime_USE_CUDA:ON \
44  * -Donnxruntime_CUDA_HOME=$CUDA_PATH -Donnxruntime_CUDNN_HOME=$CUDA_PATH \
45  *  $SRC_DIR/onnxruntime/cmake && make -j8 && sudo make install
46  *
47  *
48  * where :
49  *
50  * 1. $SRC_DIR and $BUILD_DIR are local source and build directories
51  * 2. To run with CUDA, both CUDA and cuDNN libraries must be installed.
52  *    $CUDA_PATH is an environment variable set to the CUDA root path.
53  *    On Linux, it would be /usr/local/cuda-XX.X where XX.X is the installed version of CUDA.
54  *
55  *
56  * ## Example launch command:
57  *
58  * (note: an object detection model has 3 or 4 output nodes, but there is no naming convention
59  * to indicate which node outputs the bounding box, which node outputs the label, etc.
60  * So, the `onnxobjectdetector` element has properties to map each node's functionality to its
61  * respective node index in the specified model )
62  *
63  * ```
64  * GST_DEBUG=objectdetector:5 gst-launch-1.0 multifilesrc \
65  * location=000000088462.jpg caps=image/jpeg,framerate=\(fraction\)30/1 ! jpegdec ! \
66  * videoconvert ! \
67  * onnxobjectdetector \
68  * box-node-index=0 \
69  * class-node-index=1 \
70  * score-node-index=2 \
71  * detection-node-index=3 \
72  * execution-provider=cpu \
73  * model-file=model.onnx \
74  * label-file=COCO_classes.txt  !  \
75  * videoconvert ! \
76  * autovideosink
77  * ```
78  */
79
80 #ifdef HAVE_CONFIG_H
81 #include "config.h"
82 #endif
83
84 #include "gstonnxobjectdetector.h"
85 #include "gstonnxclient.h"
86
87 #include <gst/gst.h>
88 #include <gst/video/video.h>
89 #include <gst/video/gstvideometa.h>
90 #include <stdlib.h>
91 #include <string.h>
92 #include <glib.h>
93
94 GST_DEBUG_CATEGORY_STATIC (onnx_object_detector_debug);
95 #define GST_CAT_DEFAULT onnx_object_detector_debug
96 #define GST_ONNX_MEMBER( self ) ((GstOnnxNamespace::GstOnnxClient *) (self->onnx_ptr))
97 GST_ELEMENT_REGISTER_DEFINE (onnx_object_detector, "onnxobjectdetector",
98     GST_RANK_PRIMARY, GST_TYPE_ONNX_OBJECT_DETECTOR);
99
100 /* GstOnnxObjectDetector properties */
101 enum
102 {
103   PROP_0,
104   PROP_MODEL_FILE,
105   PROP_LABEL_FILE,
106   PROP_SCORE_THRESHOLD,
107   PROP_DETECTION_NODE_INDEX,
108   PROP_BOUNDING_BOX_NODE_INDEX,
109   PROP_SCORE_NODE_INDEX,
110   PROP_CLASS_NODE_INDEX,
111   PROP_INPUT_IMAGE_FORMAT,
112   PROP_OPTIMIZATION_LEVEL,
113   PROP_EXECUTION_PROVIDER
114 };
115
116
117 #define GST_ONNX_OBJECT_DETECTOR_DEFAULT_EXECUTION_PROVIDER    GST_ONNX_EXECUTION_PROVIDER_CPU
118 #define GST_ONNX_OBJECT_DETECTOR_DEFAULT_OPTIMIZATION_LEVEL    GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED
119 #define GST_ONNX_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD       0.3f     /* 0 to 1 */
120
121 static GstStaticPadTemplate gst_onnx_object_detector_src_template =
122 GST_STATIC_PAD_TEMPLATE ("src",
123     GST_PAD_SRC,
124     GST_PAD_ALWAYS,
125     GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }"))
126     );
127
128 static GstStaticPadTemplate gst_onnx_object_detector_sink_template =
129 GST_STATIC_PAD_TEMPLATE ("sink",
130     GST_PAD_SINK,
131     GST_PAD_ALWAYS,
132     GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }"))
133     );
134
135 static void gst_onnx_object_detector_set_property (GObject * object,
136     guint prop_id, const GValue * value, GParamSpec * pspec);
137 static void gst_onnx_object_detector_get_property (GObject * object,
138     guint prop_id, GValue * value, GParamSpec * pspec);
139 static void gst_onnx_object_detector_finalize (GObject * object);
140 static GstFlowReturn gst_onnx_object_detector_transform_ip (GstBaseTransform *
141     trans, GstBuffer * buf);
142 static gboolean gst_onnx_object_detector_process (GstBaseTransform * trans,
143     GstBuffer * buf);
144 static gboolean gst_onnx_object_detector_create_session (GstBaseTransform * trans);
145 static GstCaps *gst_onnx_object_detector_transform_caps (GstBaseTransform *
146     trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps);
147
148 G_DEFINE_TYPE (GstOnnxObjectDetector, gst_onnx_object_detector,
149     GST_TYPE_BASE_TRANSFORM);
150
151 static void
152 gst_onnx_object_detector_class_init (GstOnnxObjectDetectorClass * klass)
153 {
154   GObjectClass *gobject_class = (GObjectClass *) klass;
155   GstElementClass *element_class = (GstElementClass *) klass;
156   GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass;
157
158   GST_DEBUG_CATEGORY_INIT (onnx_object_detector_debug, "onnxobjectdetector",
159       0, "onnx_objectdetector");
160   gobject_class->set_property = gst_onnx_object_detector_set_property;
161   gobject_class->get_property = gst_onnx_object_detector_get_property;
162   gobject_class->finalize = gst_onnx_object_detector_finalize;
163
164   /**
165    * GstOnnxObjectDetector:model-file
166    *
167    * ONNX model file
168    *
169    * Since: 1.20
170    */
171   g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_MODEL_FILE,
172       g_param_spec_string ("model-file",
173           "ONNX model file", "ONNX model file", NULL, (GParamFlags)
174           (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
175
176   /**
177    * GstOnnxObjectDetector:label-file
178    *
179    * Label file for ONNX model
180    *
181    * Since: 1.20
182    */
183   g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_LABEL_FILE,
184       g_param_spec_string ("label-file",
185           "Label file", "Label file associated with model", NULL, (GParamFlags)
186           (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
187
188
189   /**
190    * GstOnnxObjectDetector:detection-node-index
191    *
192    * Index of model detection node
193    *
194    * Since: 1.20
195    */
196   g_object_class_install_property (G_OBJECT_CLASS (klass),
197       PROP_DETECTION_NODE_INDEX,
198       g_param_spec_int ("detection-node-index",
199           "Detection node index",
200           "Index of neural network output node corresponding to number of detected objects",
201           GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
202           GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
203                   GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
204           (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
205
206
207   /**
208    * GstOnnxObjectDetector:bounding-box-node-index
209    *
210    * Index of model bounding box node
211    *
212    * Since: 1.20
213    */
214   g_object_class_install_property (G_OBJECT_CLASS (klass),
215       PROP_BOUNDING_BOX_NODE_INDEX,
216       g_param_spec_int ("box-node-index",
217           "Bounding box node index",
218           "Index of neural network output node corresponding to bounding box",
219           GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
220           GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
221                   GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
222           (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
223
224   /**
225    * GstOnnxObjectDetector:score-node-index
226    *
227    * Index of model score node
228    *
229    * Since: 1.20
230    */
231   g_object_class_install_property (G_OBJECT_CLASS (klass),
232       PROP_SCORE_NODE_INDEX,
233       g_param_spec_int ("score-node-index",
234           "Score node index",
235           "Index of neural network output node corresponding to score",
236           GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
237           GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
238                   GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
239           (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
240
241   /**
242    * GstOnnxObjectDetector:class-node-index
243    *
244    * Index of model class (label) node
245    *
246    * Since: 1.20
247    */
248   g_object_class_install_property (G_OBJECT_CLASS (klass),
249       PROP_CLASS_NODE_INDEX,
250       g_param_spec_int ("class-node-index",
251           "Class node index",
252           "Index of neural network output node corresponding to class (label)",
253           GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
254           GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
255                   GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
256           (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
257
258
259   /**
260    * GstOnnxObjectDetector:score-threshold
261    *
262    * Threshold for deciding when to remove boxes based on score
263    *
264    * Since: 1.20
265    */
266   g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_SCORE_THRESHOLD,
267       g_param_spec_float ("score-threshold",
268           "Score threshold",
269           "Threshold for deciding when to remove boxes based on score",
270           0.0, 1.0,
271           GST_ONNX_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD, (GParamFlags)
272           (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
273
274   /**
275    * GstOnnxObjectDetector:input-image-format
276    *
277    * Model input image format
278    *
279    * Since: 1.20
280    */
281   g_object_class_install_property (G_OBJECT_CLASS (klass),
282       PROP_INPUT_IMAGE_FORMAT,
283       g_param_spec_enum ("input-image-format",
284           "Input image format",
285           "Input image format",
286           GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT,
287           GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC, (GParamFlags)
288           (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
289
290    /**
291     * GstOnnxObjectDetector:optimization-level
292     *
293     * ONNX optimization level
294     *
295     * Since: 1.20
296     */
297   g_object_class_install_property (G_OBJECT_CLASS (klass),
298       PROP_OPTIMIZATION_LEVEL,
299       g_param_spec_enum ("optimization-level",
300           "Optimization level",
301           "ONNX optimization level",
302           GST_TYPE_ONNX_OPTIMIZATION_LEVEL,
303           GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED, (GParamFlags)
304           (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
305
306   /**
307    * GstOnnxObjectDetector:execution-provider
308    *
309    * ONNX execution provider
310    *
311    * Since: 1.20
312    */
313   g_object_class_install_property (G_OBJECT_CLASS (klass),
314       PROP_EXECUTION_PROVIDER,
315       g_param_spec_enum ("execution-provider",
316           "Execution provider",
317           "ONNX execution provider",
318           GST_TYPE_ONNX_EXECUTION_PROVIDER,
319           GST_ONNX_EXECUTION_PROVIDER_CPU, (GParamFlags)
320           (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
321
322   gst_element_class_set_static_metadata (element_class, "onnxobjectdetector",
323       "Filter/Effect/Video",
324       "Apply neural network to detect objects in video frames",
325       "Aaron Boxer <aaron.boxer@collabora.com>, Marcus Edel <marcus.edel@collabora.com>");
326   gst_element_class_add_pad_template (element_class,
327       gst_static_pad_template_get (&gst_onnx_object_detector_sink_template));
328   gst_element_class_add_pad_template (element_class,
329       gst_static_pad_template_get (&gst_onnx_object_detector_src_template));
330   basetransform_class->transform_ip =
331       GST_DEBUG_FUNCPTR (gst_onnx_object_detector_transform_ip);
332   basetransform_class->transform_caps =
333       GST_DEBUG_FUNCPTR (gst_onnx_object_detector_transform_caps);
334 }
335
336 static void
337 gst_onnx_object_detector_init (GstOnnxObjectDetector * self)
338 {
339   self->onnx_ptr = new GstOnnxNamespace::GstOnnxClient ();
340   self->onnx_disabled = false;
341 }
342
343 static void
344 gst_onnx_object_detector_finalize (GObject * object)
345 {
346   GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object);
347
348   g_free (self->model_file);
349   delete GST_ONNX_MEMBER (self);
350   G_OBJECT_CLASS (gst_onnx_object_detector_parent_class)->finalize (object);
351 }
352
353 static void
354 gst_onnx_object_detector_set_property (GObject * object, guint prop_id,
355     const GValue * value, GParamSpec * pspec)
356 {
357   GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object);
358   const gchar *filename;
359   auto onnxClient = GST_ONNX_MEMBER (self);
360
361   switch (prop_id) {
362     case PROP_MODEL_FILE:
363       filename = g_value_get_string (value);
364       if (filename
365           && g_file_test (filename,
366               (GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
367         if (self->model_file)
368           g_free (self->model_file);
369         self->model_file = g_strdup (filename);
370       } else {
371         GST_WARNING_OBJECT (self, "Model file '%s' not found!", filename);
372         gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
373       }
374       break;
375     case PROP_LABEL_FILE:
376       filename = g_value_get_string (value);
377       if (filename
378           && g_file_test (filename,
379               (GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
380         if (self->label_file)
381           g_free (self->label_file);
382         self->label_file = g_strdup (filename);
383       } else {
384         GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
385       }
386       break;
387     case PROP_SCORE_THRESHOLD:
388       GST_OBJECT_LOCK (self);
389       self->score_threshold = g_value_get_float (value);
390       GST_OBJECT_UNLOCK (self);
391       break;
392     case PROP_OPTIMIZATION_LEVEL:
393       self->optimization_level =
394           (GstOnnxOptimizationLevel) g_value_get_enum (value);
395       break;
396     case PROP_EXECUTION_PROVIDER:
397       self->execution_provider =
398           (GstOnnxExecutionProvider) g_value_get_enum (value);
399       break;
400     case PROP_DETECTION_NODE_INDEX:
401       onnxClient->setOutputNodeIndex
402           (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION,
403           g_value_get_int (value));
404       break;
405     case PROP_BOUNDING_BOX_NODE_INDEX:
406       onnxClient->setOutputNodeIndex
407           (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX,
408           g_value_get_int (value));
409       break;
410       break;
411     case PROP_SCORE_NODE_INDEX:
412       onnxClient->setOutputNodeIndex
413           (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE,
414           g_value_get_int (value));
415       break;
416       break;
417     case PROP_CLASS_NODE_INDEX:
418       onnxClient->setOutputNodeIndex
419           (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS,
420           g_value_get_int (value));
421       break;
422     case PROP_INPUT_IMAGE_FORMAT:
423       onnxClient->setInputImageFormat ((GstMlModelInputImageFormat)
424           g_value_get_enum (value));
425       break;
426     default:
427       G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
428       break;
429   }
430 }
431
432 static void
433 gst_onnx_object_detector_get_property (GObject * object, guint prop_id,
434     GValue * value, GParamSpec * pspec)
435 {
436   GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object);
437   auto onnxClient = GST_ONNX_MEMBER (self);
438
439   switch (prop_id) {
440     case PROP_MODEL_FILE:
441       g_value_set_string (value, self->model_file);
442       break;
443     case PROP_LABEL_FILE:
444       g_value_set_string (value, self->label_file);
445       break;
446     case PROP_SCORE_THRESHOLD:
447       GST_OBJECT_LOCK (self);
448       g_value_set_float (value, self->score_threshold);
449       GST_OBJECT_UNLOCK (self);
450       break;
451     case PROP_OPTIMIZATION_LEVEL:
452       g_value_set_enum (value, self->optimization_level);
453       break;
454     case PROP_EXECUTION_PROVIDER:
455       g_value_set_enum (value, self->execution_provider);
456       break;
457     case PROP_DETECTION_NODE_INDEX:
458       g_value_set_int (value,
459           onnxClient->getOutputNodeIndex
460           (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION));
461       break;
462     case PROP_BOUNDING_BOX_NODE_INDEX:
463       g_value_set_int (value,
464           onnxClient->getOutputNodeIndex
465           (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX));
466       break;
467       break;
468     case PROP_SCORE_NODE_INDEX:
469       g_value_set_int (value,
470           onnxClient->getOutputNodeIndex
471           (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE));
472       break;
473       break;
474     case PROP_CLASS_NODE_INDEX:
475       g_value_set_int (value,
476           onnxClient->getOutputNodeIndex
477           (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS));
478       break;
479     case PROP_INPUT_IMAGE_FORMAT:
480       g_value_set_enum (value, onnxClient->getInputImageFormat ());
481       break;
482     default:
483       G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
484       break;
485   }
486 }
487
488 static gboolean
489 gst_onnx_object_detector_create_session (GstBaseTransform * trans)
490 {
491   GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
492   auto onnxClient = GST_ONNX_MEMBER (self);
493
494   GST_OBJECT_LOCK (self);
495   if (self->onnx_disabled || onnxClient->hasSession ()) {
496     GST_OBJECT_UNLOCK (self);
497
498     return TRUE;
499   }
500   if (self->model_file) {
501     gboolean ret = GST_ONNX_MEMBER (self)->createSession (self->model_file,
502         self->optimization_level,
503         self->execution_provider);
504     if (!ret) {
505       GST_ERROR_OBJECT (self,
506           "Unable to create ONNX session. Detection disabled.");
507     } else {
508       auto outputNames = onnxClient->getOutputNodeNames ();
509
510       for (size_t i = 0; i < outputNames.size (); ++i)
511         GST_INFO_OBJECT (self, "Output node index: %d for node: %s", (gint) i,
512             outputNames[i]);
513       if (outputNames.size () < 3) {
514         GST_ERROR_OBJECT (self,
515             "Number of output tensor nodes %d does not match the 3 or 4 nodes "
516             "required for an object detection model. Detection is disabled.",
517             (gint) outputNames.size ());
518         self->onnx_disabled = TRUE;
519       }
520       // sanity check on output node indices
521       if (onnxClient->getOutputNodeIndex
522           (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION) ==
523           GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
524         GST_ERROR_OBJECT (self,
525             "Output detection node index not set. Detection disabled.");
526         self->onnx_disabled = TRUE;
527       }
528       if (onnxClient->getOutputNodeIndex
529           (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX) ==
530           GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
531         GST_ERROR_OBJECT (self,
532             "Output bounding box node index not set. Detection disabled.");
533         self->onnx_disabled = TRUE;
534       }
535       if (onnxClient->getOutputNodeIndex
536           (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE) ==
537           GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
538         GST_ERROR_OBJECT (self,
539             "Output score node index not set. Detection disabled.");
540         self->onnx_disabled = TRUE;
541       }
542       if (outputNames.size () == 4 && onnxClient->getOutputNodeIndex
543           (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS) ==
544           GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
545         GST_ERROR_OBJECT (self,
546             "Output class node index not set. Detection disabled.");
547         self->onnx_disabled = TRUE;
548       }
549           // model is not usable, so fail
550       if (self->onnx_disabled) {
551                   GST_ELEMENT_WARNING (self, RESOURCE, FAILED,
552                           ("ONNX model cannot be used for object detection"), (NULL));
553
554                   return FALSE;
555       }
556     }
557   } else {
558     self->onnx_disabled = TRUE;
559   }
560   GST_OBJECT_UNLOCK (self);
561   if (self->onnx_disabled){
562     gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
563   }
564
565   return TRUE;
566 }
567
568
569 static GstCaps *
570 gst_onnx_object_detector_transform_caps (GstBaseTransform *
571     trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps)
572 {
573   GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
574   auto onnxClient = GST_ONNX_MEMBER (self);
575   GstCaps *other_caps;
576   guint i;
577
578   if ( !gst_onnx_object_detector_create_session (trans) )
579           return NULL;
580   GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps);
581
582   if (gst_base_transform_is_passthrough (trans)
583       || (!onnxClient->isFixedInputImageSize ()))
584     return gst_caps_ref (caps);
585
586   other_caps = gst_caps_new_empty ();
587   for (i = 0; i < gst_caps_get_size (caps); ++i) {
588     GstStructure *structure, *new_structure;
589
590     structure = gst_caps_get_structure (caps, i);
591     new_structure = gst_structure_copy (structure);
592     gst_structure_set (new_structure, "width", G_TYPE_INT,
593         onnxClient->getWidth (), "height", G_TYPE_INT,
594         onnxClient->getHeight (), NULL);
595     GST_LOG_OBJECT (self,
596         "transformed structure %2d: %" GST_PTR_FORMAT " => %"
597         GST_PTR_FORMAT, i, structure, new_structure);
598     gst_caps_append_structure (other_caps, new_structure);
599   }
600
601   if (!gst_caps_is_empty (other_caps) && filter_caps) {
602     GstCaps *tmp = gst_caps_intersect_full (other_caps,filter_caps,
603         GST_CAPS_INTERSECT_FIRST);
604     gst_caps_replace (&other_caps, tmp);
605     gst_caps_unref (tmp);
606   }
607
608   return other_caps;
609 }
610
611
612 static GstFlowReturn
613 gst_onnx_object_detector_transform_ip (GstBaseTransform * trans,
614     GstBuffer * buf)
615 {
616   if (!gst_base_transform_is_passthrough (trans)
617       && !gst_onnx_object_detector_process (trans, buf)){
618             GST_ELEMENT_WARNING (trans, STREAM, FAILED,
619           ("ONNX object detection failed"), (NULL));
620             return GST_FLOW_ERROR;
621   }
622
623   return GST_FLOW_OK;
624 }
625
626 static gboolean
627 gst_onnx_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
628 {
629   GstMapInfo info;
630   GstVideoMeta *vmeta = gst_buffer_get_video_meta (buf);
631
632   if (!vmeta) {
633     GST_WARNING_OBJECT (trans, "missing video meta");
634     return FALSE;
635   }
636   if (gst_buffer_map (buf, &info, GST_MAP_READ)) {
637     GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
638     auto boxes = GST_ONNX_MEMBER (self)->run (info.data, vmeta,
639         self->label_file ? self->label_file : "",
640         self->score_threshold);
641   for (auto & b:boxes) {
642       auto vroi_meta = gst_buffer_add_video_region_of_interest_meta (buf,
643           GST_ONNX_OBJECT_DETECTOR_META_NAME,
644           b.x0, b.y0,
645           b.width,
646           b.height);
647       if (!vroi_meta) {
648         GST_WARNING_OBJECT (trans,
649             "Unable to attach GstVideoRegionOfInterestMeta to buffer");
650         return FALSE;
651       }
652       auto s = gst_structure_new (GST_ONNX_OBJECT_DETECTOR_META_PARAM_NAME,
653           GST_ONNX_OBJECT_DETECTOR_META_FIELD_LABEL,
654           G_TYPE_STRING,
655           b.label.c_str (),
656           GST_ONNX_OBJECT_DETECTOR_META_FIELD_SCORE,
657           G_TYPE_DOUBLE,
658           b.score,
659           NULL);
660       gst_video_region_of_interest_meta_add_param (vroi_meta, s);
661       GST_DEBUG_OBJECT (self,
662           "Object detected with label : %s, score: %f, bound box: (%f,%f,%f,%f) \n",
663           b.label.c_str (), b.score, b.x0, b.y0,
664           b.x0 + b.width, b.y0 + b.height);
665     }
666     gst_buffer_unmap (buf, &info);
667   }
668
669   return TRUE;
670 }