onnx: Update to OnnxRT >= 1.13.1 API
authorDaniel Morin <daniel.morin@collabora.com>
Thu, 10 Nov 2022 13:50:35 +0000 (08:50 -0500)
committerGStreamer Marge Bot <gitlab-merge-bot@gstreamer-foundation.org>
Tue, 22 Nov 2022 22:36:34 +0000 (22:36 +0000)
- Replace deprecated methods
- Add a check on ORT version we are compatible with.
- Add clarification to the example given.
- Add the url to retrieve the model mentioned in the example.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/3388>

subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp
subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h
subprojects/gst-plugins-bad/ext/onnx/gstonnxobjectdetector.cpp
subprojects/gst-plugins-bad/ext/onnx/meson.build

index f47abf1..a8600d2 100644 (file)
@@ -73,6 +73,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
 
 GstOnnxClient::~GstOnnxClient ()
 {
+    outputNames.clear();
     delete session;
     delete[]dest;
 }
@@ -115,6 +116,10 @@ std::string GstOnnxClient::getOutputNodeName (GstMlOutputNodeFunction nodeType)
       case GST_ML_OUTPUT_NODE_FUNCTION_CLASS:
         return "label";
         break;
+      case GST_ML_OUTPUT_NODE_NUMBER_OF:
+        g_assert_not_reached();
+        GST_WARNING("Invalid parameter");
+        break;
     };
 
     return "";
@@ -130,9 +135,16 @@ GstMlModelInputImageFormat GstOnnxClient::getInputImageFormat (void)
     return inputImageFormat;
 }
 
-std::vector < const char *>GstOnnxClient::getOutputNodeNames (void)
+std::vector< const char *> GstOnnxClient::getOutputNodeNames (void)
 {
-    return outputNames;
+    if (!outputNames.empty() && outputNamesRaw.size() != outputNames.size()) {
+        outputNamesRaw.resize(outputNames.size());
+        for (size_t i = 0; i < outputNamesRaw.size(); i++) {
+          outputNamesRaw[i] = outputNames[i].get();
+        }
+    }
+
+    return outputNamesRaw;
 }
 
 void GstOnnxClient::setOutputNodeIndex (GstMlOutputNodeFunction node,
@@ -227,11 +239,13 @@ bool GstOnnxClient::createSession (std::string modelFile,
     GST_DEBUG ("Number of Output Nodes: %d", (gint) session->GetOutputCount ());
 
     Ort::AllocatorWithDefaultOptions allocator;
-    GST_DEBUG ("Input name: %s", session->GetInputName (0, allocator));
+    auto input_name = session->GetInputNameAllocated (0, allocator);
+    GST_DEBUG ("Input name: %s", input_name.get());
 
     for (size_t i = 0; i < session->GetOutputCount (); ++i) {
-      auto output_name = session->GetOutputName (i, allocator);
-      outputNames.push_back (output_name);
+      auto output_name = session->GetOutputNameAllocated (i, allocator);
+      GST_DEBUG("Output name %lu:%s", i, output_name.get());
+      outputNames.push_back (std::move(output_name));
       auto type_info = session->GetOutputTypeInfo (i);
       auto tensor_info = type_info.GetTensorTypeAndShapeInfo ();
 
@@ -278,7 +292,7 @@ template < typename T > std::vector < GstMlBoundingBox >
     parseDimensions (vmeta);
 
     Ort::AllocatorWithDefaultOptions allocator;
-    auto inputName = session->GetInputName (0, allocator);
+    auto inputName = session->GetInputNameAllocated (0, allocator);
     auto inputTypeInfo = session->GetInputTypeInfo (0);
     std::vector < int64_t > inputDims =
         inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
@@ -366,11 +380,11 @@ template < typename T > std::vector < GstMlBoundingBox >
     std::vector < Ort::Value > inputTensors;
     inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo,
             dest, inputTensorSize, inputDims.data (), inputDims.size ()));
-    std::vector < const char *>inputNames { inputName };
+    std::vector < const char *>inputNames { inputName.get () };
 
     std::vector < Ort::Value > modelOutput = session->Run (Ort::RunOptions { nullptr},
         inputNames.data (),
-        inputTensors.data (), 1, outputNames.data (), outputNames.size ());
+        inputTensors.data (), 1, outputNamesRaw.data (), outputNamesRaw.size ());
 
     auto numDetections =
         modelOutput[getOutputNodeIndex
index 769cd11..edbc2f4 100644 (file)
@@ -108,7 +108,8 @@ namespace GstOnnxNamespace {
     GstMlOutputNodeInfo outputNodeInfo[GST_ML_OUTPUT_NODE_NUMBER_OF];
     // !! indexed by array index
        size_t outputNodeIndexToFunction[GST_ML_OUTPUT_NODE_NUMBER_OF];
-    std::vector < const char *>outputNames;
+    std::vector < const char *> outputNamesRaw;
+    std::vector < Ort::AllocatedStringPtr > outputNames;
     GstMlModelInputImageFormat inputImageFormat;
     bool fixedInputImageSize;
   };
index 28f4cf2..680b02f 100644 (file)
  *
  * ## Example launch command:
  *
- * (note: an object detection model has 3 or 4 output nodes, but there is no naming convention
- * to indicate which node outputs the bounding box, which node outputs the label, etc.
- * So, the `onnxobjectdetector` element has properties to map each node's functionality to its
- * respective node index in the specified model )
+ * (note: an object detection model has 3 or 4 output nodes, but there is no
+ * naming convention to indicate which node outputs the bounding box, which
+ * node outputs the label, etc. So, the `onnxobjectdetector` element has
+ * properties to map each node's functionality to its respective node index in
+ * the specified model. Image resolution also need to be adapted to the model.
+ * The videoscale in the pipeline below will scale the image, using padding if
+ * required, to 640x383 resolution required by the model.)
+ *
+ * model.onnx can be found here:
+ * https://github.com/zoq/onnx-runtime-examples/raw/main/data/models/model.onnx
  *
  * ```
  * GST_DEBUG=objectdetector:5 gst-launch-1.0 multifilesrc \
  * location=000000088462.jpg caps=image/jpeg,framerate=\(fraction\)30/1 ! jpegdec ! \
  * videoconvert ! \
- * onnxobjectdetector \
+ * videoscale ! \
+ * 'video/x-raw,width=640,height=383' ! \
+ * onnxobjectdetector ! \
  * box-node-index=0 \
  * class-node-index=1 \
  * score-node-index=2 \
index ff91739..e66d649 100644 (file)
@@ -3,7 +3,7 @@ if get_option('onnx').disabled()
 endif
 
 
-onnxrt_dep = dependency('libonnxruntime',required : get_option('onnx'))
+onnxrt_dep = dependency('libonnxruntime', version : '>= 1.13.1', required : get_option('onnx'))
 
 if onnxrt_dep.found()
        onnxrt_include_root = onnxrt_dep.get_variable('includedir')