2 * GStreamer gstreamer-onnxclient
3 * Copyright (C) 2021 Collabora Ltd
7 * This library is free software; you can redistribute it and/or
8 * modify it under the terms of the GNU Library General Public
9 * License as published by the Free Software Foundation; either
10 * version 2 of the License, or (at your option) any later version.
12 * This library is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15 * Library General Public License for more details.
17 * You should have received a copy of the GNU Library General Public
18 * License along with this library; if not, write to the
19 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
20 * Boston, MA 02110-1301, USA.
22 #ifndef __GST_ONNX_CLIENT_H__
23 #define __GST_ONNX_CLIENT_H__
26 #include <onnxruntime_cxx_api.h>
27 #include <gst/video/video.h>
28 #include "gstonnxelement.h"
32 namespace GstOnnxNamespace {
// Identifies the function (role) of a model output node. Values of this enum
// are the keys taken by the set/getOutputNode* accessors declared below.
// The enumerators have no explicit values, so they are numbered 0..3 in
// declaration order and GST_ML_OUTPUT_NODE_NUMBER_OF evaluates to their
// count (4); it is not a node function itself but is used as the length of
// the outputNodeInfo / outputNodeIndexToFunction arrays.
// NOTE(review): the exact tensor each function refers to (detection count,
// boxes, scores, class ids) is inferred from the names only - confirm
// against the implementation.
33 enum GstMlOutputNodeFunction {
34 GST_ML_OUTPUT_NODE_FUNCTION_DETECTION,
35 GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX,
36 GST_ML_OUTPUT_NODE_FUNCTION_SCORE,
37 GST_ML_OUTPUT_NODE_FUNCTION_CLASS,
38 GST_ML_OUTPUT_NODE_NUMBER_OF,
// Sentinel output-node index (-1).
// NOTE(review): the name suggests it marks a node function that has no
// output index assigned - confirm against setOutputNodeIndex() /
// getOutputNodeIndex() in the implementation.
41 const gint GST_ML_NODE_INDEX_DISABLED = -1;
// Information kept for one model output node; an array of these, one per
// GstMlOutputNodeFunction, is declared further down as outputNodeInfo.
43 struct GstMlOutputNodeInfo {
// Default constructor; only declared here, so the initial member values are
// set in the implementation file.
44 GstMlOutputNodeInfo(void);
// ONNX tensor element data type associated with this node (same type that
// setOutputNodeType() / getOutputNodeType() exchange).
46 ONNXTensorElementDataType type;
// One labelled result box: a text label, a score, an origin (x0, y0) and a
// width/height, all stored as plain members initialised by the constructors.
// NOTE(review): coordinate units and the score range are not established
// here - confirm in the code that fills these in.
49 struct GstMlBoundingBox {
// Member-wise constructor: each argument is copied into the member of the
// corresponding name (label, score, x0, y0, width, height).
50 GstMlBoundingBox(std::string lbl,
55 float _height):label(lbl),
56 score(score), x0(_x0), y0(_y0), width(_width), height(_height) {
// Default constructor: delegates to the member-wise constructor with an
// empty label and every numeric field set to 0.0f.
58 GstMlBoundingBox():GstMlBoundingBox("", 0.0f, 0.0f, 0.0f, 0.0f, 0.0f) {
// Creates the inference session for the model stored in `modelFile`, using
// the requested graph optimization level and execution provider.
// NOTE(review): the bool result presumably reports success - confirm in the
// implementation, including what happens when a session already exists.
72 bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim,
73 GstOnnxExecutionProvider provider);
// Reports whether a session is available.
// NOTE(review): presumably equivalent to "session is non-null" - confirm.
74 bool hasSession(void);
// Accessors for the image format the model expects as input (backed,
// presumably, by the inputImageFormat member - confirm).
75 void setInputImageFormat(GstMlModelInputImageFormat format);
76 GstMlModelInputImageFormat getInputImageFormat(void);
// Per-function output node configuration. `nodeType` / `type` selects which
// node function (detection, bounding box, score, class) is addressed.
// NOTE(review): whether an out-of-range function or the
// GST_ML_NODE_INDEX_DISABLED index is accepted is not visible here.
// Index of the model output associated with the given node function.
77 void setOutputNodeIndex(GstMlOutputNodeFunction nodeType, gint index);
78 gint getOutputNodeIndex(GstMlOutputNodeFunction nodeType);
// ONNX tensor element data type associated with the given node function.
79 void setOutputNodeType(GstMlOutputNodeFunction nodeType,
80 ONNXTensorElementDataType type);
81 ONNXTensorElementDataType getOutputNodeType(GstMlOutputNodeFunction type);
// Name of the model output associated with the given node function.
82 std::string getOutputNodeName(GstMlOutputNodeFunction nodeType);
// Runs the model on the image bytes in `img_data` and returns the resulting
// bounding boxes by value.
// NOTE(review): `labelPath` is presumably a labels file used to name the
// boxes and `scoreThreshold` a minimum score for a box to be returned -
// confirm both, and the expected pixel layout of `img_data`, in the
// implementation.
83 std::vector < GstMlBoundingBox > run(uint8_t * img_data,
85 std::string labelPath,
86 float scoreThreshold);
// Returns a reference to a bounding-box vector.
// NOTE(review): which object owns that vector, and how long the reference
// stays valid, is not visible here - confirm before caching it.
87 std::vector < GstMlBoundingBox > &getBoundingBoxes(void);
// Returns the model's output node names as C strings.
// NOTE(review): the pointers are non-owning; presumably they point into
// storage kept by this object (see outputNames) - confirm lifetime.
88 std::vector < const char *>getOutputNodeNames(void);
// Reports whether the model requires a fixed input image size (presumably
// the value of the fixedInputImageSize member - confirm).
89 bool isFixedInputImageSize(void);
// Input width and height as 32-bit signed integers.
// NOTE(review): presumably the model's input dimensions in pixels; the
// value returned when no size is known is not visible here.
90 int32_t getWidth(void);
91 int32_t getHeight(void);
// Helper that takes the frame's video meta.
// NOTE(review): presumably derives the input width/height/channels from
// `vmeta` - confirm what state it updates.
93 void parseDimensions(GstVideoMeta * vmeta);
// Typed worker with the same image/label/threshold inputs as run(), plus the
// frame's video meta. T is a template type parameter.
// NOTE(review): T is presumably the element type of the input tensor and
// run() presumably dispatches to this - confirm in the implementation.
94 template < typename T > std::vector < GstMlBoundingBox >
95 doRun(uint8_t * img_data, GstVideoMeta * vmeta, std::string labelPath,
96 float scoreThreshold);
// Reads labels from `labelsFile` and returns them as strings.
// NOTE(review): file format (e.g. one label per line) and error behaviour
// are not visible here.
97 std::vector < std::string > ReadLabels(const std::string & labelsFile);
// Returns a reference to the ONNX Runtime environment.
// NOTE(review): whether this is a shared/static instance is not visible.
98 Ort::Env & getEnv(void);
// ONNX Runtime session, held through a raw pointer.
// NOTE(review): ownership and deletion are not visible here - check the
// constructor/destructor in the implementation.
99 Ort::Session * session;
// Execution provider setting (same type as createSession()'s `provider`).
104 GstOnnxExecutionProvider m_provider;
// Output tensors of the model (Ort::Value objects).
105 std::vector < Ort::Value > modelOutput;
// Label strings (presumably the result of ReadLabels() - confirm).
106 std::vector < std::string > labels;
107 // Indexed by node function: use a GstMlOutputNodeFunction value as the subscript.
108 GstMlOutputNodeInfo outputNodeInfo[GST_ML_OUTPUT_NODE_NUMBER_OF];
109 // Indexed by output node (array) index, not by function; the name suggests each entry holds the matching GstMlOutputNodeFunction - confirm.
110 size_t outputNodeIndexToFunction[GST_ML_OUTPUT_NODE_NUMBER_OF];
// Output node names in two forms: `outputNames` holds the
// Ort::AllocatedStringPtr handles, `outputNamesRaw` holds plain C-string
// pointers (the form getOutputNodeNames() returns).
// NOTE(review): presumably the raw pointers alias the strings owned by
// `outputNames`, so the two vectors must be kept in sync - confirm.
111 std::vector < const char *> outputNamesRaw;
112 std::vector < Ort::AllocatedStringPtr > outputNames;
// Input image format configured through setInputImageFormat().
113 GstMlModelInputImageFormat inputImageFormat;
// Whether the model's input image size is fixed (see isFixedInputImageSize()).
114 bool fixedInputImageSize;
118 #endif /* __GST_ONNX_CLIENT_H__ */