onnx: Update to OnnxRT >= 1.13.1 API
[platform/upstream/gstreamer.git] / subprojects / gst-plugins-bad / ext / onnx / gstonnxclient.cpp
1 /*
2  * GStreamer gstreamer-onnxclient
3  * Copyright (C) 2021 Collabora Ltd
4  *
5  * gstonnxclient.cpp
6  *
7  * This library is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Library General Public
9  * License as published by the Free Software Foundation; either
10  * version 2 of the License, or (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Library General Public License for more details.
16  *
17  * You should have received a copy of the GNU Library General Public
18  * License along with this library; if not, write to the
19  * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
20  * Boston, MA 02110-1301, USA.
21  */
22
23 #include "gstonnxclient.h"
24 #include <providers/cpu/cpu_provider_factory.h>
25 #ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA
26 #include <providers/cuda/cuda_provider_factory.h>
27 #endif
28 #include <exception>
29 #include <fstream>
30 #include <iostream>
31 #include <limits>
32 #include <numeric>
33 #include <cmath>
34 #include <sstream>
35
36 namespace GstOnnxNamespace
37 {
38 template < typename T >
39     std::ostream & operator<< (std::ostream & os, const std::vector < T > &v)
40 {
41     os << "[";
42     for (size_t i = 0; i < v.size (); ++i)
43     {
44       os << v[i];
45       if (i != v.size () - 1)
46       {
47         os << ", ";
48       }
49     }
50     os << "]";
51
52     return os;
53 }
54
55 GstMlOutputNodeInfo::GstMlOutputNodeInfo (void):index
56   (GST_ML_NODE_INDEX_DISABLED),
57   type (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
58 {
59 }
60
61 GstOnnxClient::GstOnnxClient ():session (nullptr),
62       width (0),
63       height (0),
64       channels (0),
65       dest (nullptr),
66       m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU),
67       inputImageFormat (GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC),
68       fixedInputImageSize (true)
69 {
70     for (size_t i = 0; i < GST_ML_OUTPUT_NODE_NUMBER_OF; ++i)
71       outputNodeIndexToFunction[i] = (GstMlOutputNodeFunction) i;
72 }
73
74 GstOnnxClient::~GstOnnxClient ()
75 {
76     outputNames.clear();
77     delete session;
78     delete[]dest;
79 }
80
81 Ort::Env & GstOnnxClient::getEnv (void)
82 {
83     static Ort::Env env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
84         "GstOnnxNamespace");
85
86     return env;
87 }
88
89 int32_t GstOnnxClient::getWidth (void)
90 {
91     return width;
92 }
93
94 int32_t GstOnnxClient::getHeight (void)
95 {
96     return height;
97 }
98
99 bool GstOnnxClient::isFixedInputImageSize (void)
100 {
101     return fixedInputImageSize;
102 }
103
104 std::string GstOnnxClient::getOutputNodeName (GstMlOutputNodeFunction nodeType)
105 {
106     switch (nodeType) {
107       case GST_ML_OUTPUT_NODE_FUNCTION_DETECTION:
108         return "detection";
109         break;
110       case GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX:
111         return "bounding box";
112         break;
113       case GST_ML_OUTPUT_NODE_FUNCTION_SCORE:
114         return "score";
115         break;
116       case GST_ML_OUTPUT_NODE_FUNCTION_CLASS:
117         return "label";
118         break;
119       case GST_ML_OUTPUT_NODE_NUMBER_OF:
120         g_assert_not_reached();
121         GST_WARNING("Invalid parameter");
122         break;
123     };
124
125     return "";
126 }
127
128 void GstOnnxClient::setInputImageFormat (GstMlModelInputImageFormat format)
129 {
130     inputImageFormat = format;
131 }
132
133 GstMlModelInputImageFormat GstOnnxClient::getInputImageFormat (void)
134 {
135     return inputImageFormat;
136 }
137
138 std::vector< const char *> GstOnnxClient::getOutputNodeNames (void)
139 {
140     if (!outputNames.empty() && outputNamesRaw.size() != outputNames.size()) {
141         outputNamesRaw.resize(outputNames.size());
142         for (size_t i = 0; i < outputNamesRaw.size(); i++) {
143           outputNamesRaw[i] = outputNames[i].get();
144         }
145     }
146
147     return outputNamesRaw;
148 }
149
150 void GstOnnxClient::setOutputNodeIndex (GstMlOutputNodeFunction node,
151       gint index)
152 {
153     g_assert (index < GST_ML_OUTPUT_NODE_NUMBER_OF);
154     outputNodeInfo[node].index = index;
155     if (index != GST_ML_NODE_INDEX_DISABLED)
156       outputNodeIndexToFunction[index] = node;
157 }
158
159 gint GstOnnxClient::getOutputNodeIndex (GstMlOutputNodeFunction node)
160 {
161     return outputNodeInfo[node].index;
162 }
163
164 void GstOnnxClient::setOutputNodeType (GstMlOutputNodeFunction node,
165       ONNXTensorElementDataType type)
166 {
167     outputNodeInfo[node].type = type;
168 }
169
170 ONNXTensorElementDataType
171       GstOnnxClient::getOutputNodeType (GstMlOutputNodeFunction node)
172 {
173     return outputNodeInfo[node].type;
174 }
175
176 bool GstOnnxClient::hasSession (void)
177 {
178     return session != nullptr;
179 }
180
181 bool GstOnnxClient::createSession (std::string modelFile,
182       GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider)
183 {
184     if (session)
185       return true;
186
187     GraphOptimizationLevel onnx_optim;
188     switch (optim) {
189       case GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL:
190         onnx_optim = GraphOptimizationLevel::ORT_DISABLE_ALL;
191         break;
192       case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_BASIC:
193         onnx_optim = GraphOptimizationLevel::ORT_ENABLE_BASIC;
194         break;
195       case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED:
196         onnx_optim = GraphOptimizationLevel::ORT_ENABLE_EXTENDED;
197         break;
198       case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_ALL:
199         onnx_optim = GraphOptimizationLevel::ORT_ENABLE_ALL;
200         break;
201       default:
202         onnx_optim = GraphOptimizationLevel::ORT_ENABLE_EXTENDED;
203         break;
204     };
205
206     Ort::SessionOptions sessionOptions;
207     // for debugging
208     //sessionOptions.SetIntraOpNumThreads (1);
209     sessionOptions.SetGraphOptimizationLevel (onnx_optim);
210     m_provider = provider;
211     switch (m_provider) {
212       case GST_ONNX_EXECUTION_PROVIDER_CUDA:
213 #ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA
214         Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CUDA
215             (sessionOptions, 0));
216 #else
217         return false;
218 #endif
219         break;
220       default:
221         break;
222
223     };
224     session = new Ort::Session (getEnv (), modelFile.c_str (), sessionOptions);
225     auto inputTypeInfo = session->GetInputTypeInfo (0);
226     std::vector < int64_t > inputDims =
227         inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
228     if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
229       height = inputDims[1];
230       width = inputDims[2];
231       channels = inputDims[3];
232     } else {
233       channels = inputDims[1];
234       height = inputDims[2];
235       width = inputDims[3];
236     }
237
238     fixedInputImageSize = width > 0 && height > 0;
239     GST_DEBUG ("Number of Output Nodes: %d", (gint) session->GetOutputCount ());
240
241     Ort::AllocatorWithDefaultOptions allocator;
242     auto input_name = session->GetInputNameAllocated (0, allocator);
243     GST_DEBUG ("Input name: %s", input_name.get());
244
245     for (size_t i = 0; i < session->GetOutputCount (); ++i) {
246       auto output_name = session->GetOutputNameAllocated (i, allocator);
247       GST_DEBUG("Output name %lu:%s", i, output_name.get());
248       outputNames.push_back (std::move(output_name));
249       auto type_info = session->GetOutputTypeInfo (i);
250       auto tensor_info = type_info.GetTensorTypeAndShapeInfo ();
251
252       if (i < GST_ML_OUTPUT_NODE_NUMBER_OF) {
253         auto function = outputNodeIndexToFunction[i];
254         outputNodeInfo[function].type = tensor_info.GetElementType ();
255       }
256     }
257
258     return true;
259 }
260
261 std::vector < GstMlBoundingBox > GstOnnxClient::run (uint8_t * img_data,
262       GstVideoMeta * vmeta, std::string labelPath, float scoreThreshold)
263 {
264     auto type = getOutputNodeType (GST_ML_OUTPUT_NODE_FUNCTION_CLASS);
265     return (type ==
266         ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) ?
267           doRun < float >(img_data, vmeta, labelPath, scoreThreshold)
268             : doRun < int >(img_data, vmeta, labelPath, scoreThreshold);
269 }
270
271 void GstOnnxClient::parseDimensions (GstVideoMeta * vmeta)
272 {
273     int32_t newWidth = fixedInputImageSize ? width : vmeta->width;
274     int32_t newHeight = fixedInputImageSize ? height : vmeta->height;
275
276     if (!dest || width * height < newWidth * newHeight) {
277       delete[] dest;
278       dest = new uint8_t[newWidth * newHeight * channels];
279     }
280     width = newWidth;
281     height = newHeight;
282 }
283
284 template < typename T > std::vector < GstMlBoundingBox >
285       GstOnnxClient::doRun (uint8_t * img_data, GstVideoMeta * vmeta,
286       std::string labelPath, float scoreThreshold)
287 {
288     std::vector < GstMlBoundingBox > boundingBoxes;
289     if (!img_data)
290       return boundingBoxes;
291
292     parseDimensions (vmeta);
293
294     Ort::AllocatorWithDefaultOptions allocator;
295     auto inputName = session->GetInputNameAllocated (0, allocator);
296     auto inputTypeInfo = session->GetInputTypeInfo (0);
297     std::vector < int64_t > inputDims =
298         inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
299     inputDims[0] = 1;
300     if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
301       inputDims[1] = height;
302       inputDims[2] = width;
303     } else {
304       inputDims[2] = height;
305       inputDims[3] = width;
306     }
307
308     std::ostringstream buffer;
309     buffer << inputDims;
310     GST_DEBUG ("Input dimensions: %s", buffer.str ().c_str ());
311
312     // copy video frame
313     uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 };
314     uint32_t srcSamplesPerPixel = 3;
315     switch (vmeta->format) {
316       case GST_VIDEO_FORMAT_RGBA:
317         srcSamplesPerPixel = 4;
318         break;
319       case GST_VIDEO_FORMAT_BGRA:
320         srcSamplesPerPixel = 4;
321         srcPtr[0] = img_data + 2;
322         srcPtr[1] = img_data + 1;
323         srcPtr[2] = img_data + 0;
324         break;
325       case GST_VIDEO_FORMAT_ARGB:
326         srcSamplesPerPixel = 4;
327         srcPtr[0] = img_data + 1;
328         srcPtr[1] = img_data + 2;
329         srcPtr[2] = img_data + 3;
330         break;
331       case GST_VIDEO_FORMAT_ABGR:
332         srcSamplesPerPixel = 4;
333         srcPtr[0] = img_data + 3;
334         srcPtr[1] = img_data + 2;
335         srcPtr[2] = img_data + 1;
336         break;
337       case GST_VIDEO_FORMAT_BGR:
338         srcPtr[0] = img_data + 2;
339         srcPtr[1] = img_data + 1;
340         srcPtr[2] = img_data + 0;
341         break;
342       default:
343         break;
344     }
345     size_t destIndex = 0;
346     uint32_t stride = vmeta->stride[0];
347     if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
348       for (int32_t j = 0; j < height; ++j) {
349         for (int32_t i = 0; i < width; ++i) {
350           for (int32_t k = 0; k < channels; ++k) {
351             dest[destIndex++] = *srcPtr[k];
352             srcPtr[k] += srcSamplesPerPixel;
353           }
354         }
355         // correct for stride
356         for (uint32_t k = 0; k < 3; ++k)
357           srcPtr[k] += stride - srcSamplesPerPixel * width;
358       }
359     } else {
360       size_t frameSize = width * height;
361       uint8_t *destPtr[3] = { dest, dest + frameSize, dest + 2 * frameSize };
362       for (int32_t j = 0; j < height; ++j) {
363         for (int32_t i = 0; i < width; ++i) {
364           for (int32_t k = 0; k < channels; ++k) {
365             destPtr[k][destIndex] = *srcPtr[k];
366             srcPtr[k] += srcSamplesPerPixel;
367           }
368           destIndex++;
369         }
370         // correct for stride
371         for (uint32_t k = 0; k < 3; ++k)
372           srcPtr[k] += stride - srcSamplesPerPixel * width;
373       }
374     }
375
376     const size_t inputTensorSize = width * height * channels;
377     auto memoryInfo =
378         Ort::MemoryInfo::CreateCpu (OrtAllocatorType::OrtArenaAllocator,
379         OrtMemType::OrtMemTypeDefault);
380     std::vector < Ort::Value > inputTensors;
381     inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo,
382             dest, inputTensorSize, inputDims.data (), inputDims.size ()));
383     std::vector < const char *>inputNames { inputName.get () };
384
385     std::vector < Ort::Value > modelOutput = session->Run (Ort::RunOptions { nullptr},
386         inputNames.data (),
387         inputTensors.data (), 1, outputNamesRaw.data (), outputNamesRaw.size ());
388
389     auto numDetections =
390         modelOutput[getOutputNodeIndex
391         (GST_ML_OUTPUT_NODE_FUNCTION_DETECTION)].GetTensorMutableData < float >();
392     auto bboxes =
393         modelOutput[getOutputNodeIndex
394         (GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX)].GetTensorMutableData < float >();
395     auto scores =
396         modelOutput[getOutputNodeIndex
397         (GST_ML_OUTPUT_NODE_FUNCTION_SCORE)].GetTensorMutableData < float >();
398     T *labelIndex = nullptr;
399     if (getOutputNodeIndex (GST_ML_OUTPUT_NODE_FUNCTION_CLASS) !=
400         GST_ML_NODE_INDEX_DISABLED) {
401       labelIndex =
402           modelOutput[getOutputNodeIndex
403           (GST_ML_OUTPUT_NODE_FUNCTION_CLASS)].GetTensorMutableData < T > ();
404     }
405     if (labels.empty () && !labelPath.empty ())
406       labels = ReadLabels (labelPath);
407
408     for (int i = 0; i < numDetections[0]; ++i) {
409       if (scores[i] > scoreThreshold) {
410         std::string label = "";
411
412         if (labelIndex && !labels.empty ())
413           label = labels[labelIndex[i] - 1];
414         auto score = scores[i];
415         auto y0 = bboxes[i * 4] * height;
416         auto x0 = bboxes[i * 4 + 1] * width;
417         auto bheight = bboxes[i * 4 + 2] * height - y0;
418         auto bwidth = bboxes[i * 4 + 3] * width - x0;
419         boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, bwidth,
420                 bheight));
421       }
422     }
423     return boundingBoxes;
424 }
425
426 std::vector < std::string >
427     GstOnnxClient::ReadLabels (const std::string & labelsFile)
428 {
429     std::vector < std::string > labels;
430     std::string line;
431     std::ifstream fp (labelsFile);
432     while (std::getline (fp, line))
433       labels.push_back (line);
434
435     return labels;
436   }
437 }