onnx: Update to OnnxRT >= 1.13.1 API
[platform/upstream/gstreamer.git] / subprojects / gst-plugins-bad / ext / onnx / gstonnxclient.h
1 /*
2  * GStreamer gstreamer-onnxclient
3  * Copyright (C) 2021 Collabora Ltd
4  *
5  * gstonnxclient.h
6  *
7  * This library is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Library General Public
9  * License as published by the Free Software Foundation; either
10  * version 2 of the License, or (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Library General Public License for more details.
16  *
17  * You should have received a copy of the GNU Library General Public
18  * License along with this library; if not, write to the
19  * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
20  * Boston, MA 02110-1301, USA.
21  */
22 #ifndef __GST_ONNX_CLIENT_H__
23 #define __GST_ONNX_CLIENT_H__
24
25 #include <gst/gst.h>
26 #include <onnxruntime_cxx_api.h>
27 #include <gst/video/video.h>
28 #include "gstonnxelement.h"
29 #include <string>
30 #include <vector>
31
32 namespace GstOnnxNamespace {
33   enum GstMlOutputNodeFunction {
34     GST_ML_OUTPUT_NODE_FUNCTION_DETECTION,
35     GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX,
36     GST_ML_OUTPUT_NODE_FUNCTION_SCORE,
37     GST_ML_OUTPUT_NODE_FUNCTION_CLASS,
38     GST_ML_OUTPUT_NODE_NUMBER_OF,
39   };
40
41   const gint GST_ML_NODE_INDEX_DISABLED = -1;
42
43   struct GstMlOutputNodeInfo {
44     GstMlOutputNodeInfo(void);
45         gint index;
46     ONNXTensorElementDataType type;
47   };
48
49   struct GstMlBoundingBox {
50     GstMlBoundingBox(std::string lbl,
51                      float score,
52                      float _x0,
53                      float _y0,
54                      float _width,
55                      float _height):label(lbl),
56       score(score), x0(_x0), y0(_y0), width(_width), height(_height) {
57     }
58     GstMlBoundingBox():GstMlBoundingBox("", 0.0f, 0.0f, 0.0f, 0.0f, 0.0f) {
59     }
60     std::string label;
61     float score;
62     float x0;
63     float y0;
64     float width;
65     float height;
66   };
67
68   class GstOnnxClient {
69   public:
70     GstOnnxClient(void);
71     ~GstOnnxClient(void);
72     bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim,
73                        GstOnnxExecutionProvider provider);
74     bool hasSession(void);
75     void setInputImageFormat(GstMlModelInputImageFormat format);
76     GstMlModelInputImageFormat getInputImageFormat(void);
77     void setOutputNodeIndex(GstMlOutputNodeFunction nodeType, gint index);
78     gint getOutputNodeIndex(GstMlOutputNodeFunction nodeType);
79     void setOutputNodeType(GstMlOutputNodeFunction nodeType,
80                            ONNXTensorElementDataType type);
81     ONNXTensorElementDataType getOutputNodeType(GstMlOutputNodeFunction type);
82     std::string getOutputNodeName(GstMlOutputNodeFunction nodeType);
83     std::vector < GstMlBoundingBox > run(uint8_t * img_data,
84                                           GstVideoMeta * vmeta,
85                                           std::string labelPath,
86                                           float scoreThreshold);
87     std::vector < GstMlBoundingBox > &getBoundingBoxes(void);
88     std::vector < const char *>getOutputNodeNames(void);
89     bool isFixedInputImageSize(void);
90     int32_t getWidth(void);
91     int32_t getHeight(void);
92   private:
93     void parseDimensions(GstVideoMeta * vmeta);
94     template < typename T > std::vector < GstMlBoundingBox >
95     doRun(uint8_t * img_data, GstVideoMeta * vmeta, std::string labelPath,
96             float scoreThreshold);
97     std::vector < std::string > ReadLabels(const std::string & labelsFile);
98     Ort::Env & getEnv(void);
99     Ort::Session * session;
100     int32_t width;
101     int32_t height;
102     int32_t channels;
103     uint8_t *dest;
104     GstOnnxExecutionProvider m_provider;
105     std::vector < Ort::Value > modelOutput;
106     std::vector < std::string > labels;
107     // !! indexed by function
108     GstMlOutputNodeInfo outputNodeInfo[GST_ML_OUTPUT_NODE_NUMBER_OF];
109     // !! indexed by array index
110         size_t outputNodeIndexToFunction[GST_ML_OUTPUT_NODE_NUMBER_OF];
111     std::vector < const char *> outputNamesRaw;
112     std::vector < Ort::AllocatedStringPtr > outputNames;
113     GstMlModelInputImageFormat inputImageFormat;
114     bool fixedInputImageSize;
115   };
116 }
117
118 #endif                          /* __GST_ONNX_CLIENT_H__ */