2 * GStreamer gstreamer-onnxclient
3 * Copyright (C) 2021 Collabora Ltd
7 * This library is free software; you can redistribute it and/or
8 * modify it under the terms of the GNU Library General Public
9 * License as published by the Free Software Foundation; either
10 * version 2 of the License, or (at your option) any later version.
12 * This library is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15 * Library General Public License for more details.
17 * You should have received a copy of the GNU Library General Public
18 * License along with this library; if not, write to the
19 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
20 * Boston, MA 02110-1301, USA.
23 #include "gstonnxclient.h"
24 #include <providers/cpu/cpu_provider_factory.h>
25 #ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA
26 #include <providers/cuda/cuda_provider_factory.h>
36 namespace GstOnnxNamespace
// Stream-insertion helper for std::vector<T>. Walks the vector by index and
// treats every element except the last one specially (i != v.size () - 1).
// NOTE(review): presumably that branch emits a separator between elements;
// the statements that actually write to `os` are not visible here - confirm.
38 template < typename T >
39 std::ostream & operator<< (std::ostream & os, const std::vector < T > &v)
42 for (size_t i = 0; i < v.size (); ++i)
45 if (i != v.size () - 1)
// Default constructor: the node index starts out as
// GST_ML_NODE_INDEX_DISABLED and the element type as
// ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT.
55 GstMlOutputNodeInfo::GstMlOutputNodeInfo (void):index
56 (GST_ML_NODE_INDEX_DISABLED),
57 type (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
// Constructs a client with no session, the CPU execution provider, HWC input
// layout and fixedInputImageSize set to true.
61 GstOnnxClient::GstOnnxClient ():session (nullptr),
66 m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU),
67 inputImageFormat (GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC),
68 fixedInputImageSize (true)
// Start with an identity mapping: output slot i is assigned function i until
// setOutputNodeIndex () overrides the entry.
70 for (size_t i = 0; i < GST_ML_OUTPUT_NODE_NUMBER_OF; ++i)
71 outputNodeIndexToFunction[i] = (GstMlOutputNodeFunction) i;
// Destructor.
// NOTE(review): the body is not visible here. `session` is allocated with
// `new` in createSession () and `dest` with `new[]` in parseDimensions (), so
// both need to be released here - confirm.
74 GstOnnxClient::~GstOnnxClient ()
// Accessor for the ONNX Runtime environment that createSession () passes to
// Ort::Session. The Env is a function-local static constructed with
// warning-level logging, so it is created once on first use.
// NOTE(review): the remaining constructor argument(s) and the return
// statement are not visible here.
81 Ort::Env & GstOnnxClient::getEnv (void)
83 static Ort::Env env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
// Accessor for the input width.
// NOTE(review): body not visible here; presumably returns the `width` member
// that createSession () reads from the model - confirm.
89 int32_t GstOnnxClient::getWidth (void)
// Accessor for the input height.
// NOTE(review): body not visible here; presumably returns the `height` member
// that createSession () reads from the model - confirm.
94 int32_t GstOnnxClient::getHeight (void)
// Returns fixedInputImageSize. createSession () sets it to true only when both
// the width and the height read from the model input shape are positive.
99 bool GstOnnxClient::isFixedInputImageSize (void)
101 return fixedInputImageSize;
// Maps an output node function to a human-readable name.
// Only the "bounding box" string is visible here; the values returned for the
// other cases are not shown.
104 std::string GstOnnxClient::getOutputNodeName (GstMlOutputNodeFunction nodeType)
107 case GST_ML_OUTPUT_NODE_FUNCTION_DETECTION:
110 case GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX:
111 return "bounding box";
113 case GST_ML_OUTPUT_NODE_FUNCTION_SCORE:
116 case GST_ML_OUTPUT_NODE_FUNCTION_CLASS:
// GST_ML_OUTPUT_NODE_NUMBER_OF is used elsewhere in this file as a loop bound
// over the node functions, so it is not a valid argument here.
119 case GST_ML_OUTPUT_NODE_NUMBER_OF:
// NOTE(review): with assertions enabled g_assert_not_reached () aborts, so the
// warning below can only run in builds compiled with G_DISABLE_ASSERT.
// Logging before asserting would make the message visible in both cases.
120 g_assert_not_reached();
121 GST_WARNING("Invalid parameter");
// Stores the model input layout. createSession () and doRun () compare it
// against GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC to decide which dimension
// indices hold height, width and channels, so it has to be set before
// createSession () is called to take effect there.
128 void GstOnnxClient::setInputImageFormat (GstMlModelInputImageFormat format)
130 inputImageFormat = format;
// Returns the model input layout last stored by setInputImageFormat ()
// (GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC by default, see the constructor).
133 GstMlModelInputImageFormat GstOnnxClient::getInputImageFormat (void)
135 return inputImageFormat;
// Returns the output node names as raw C strings. outputNamesRaw is the same
// vector doRun () hands to session->Run (); it caches pointers into the
// entries of outputNames, which createSession () fills.
// The cache is rebuilt only when outputNames is non-empty and the two vectors
// differ in size.
138 std::vector< const char *> GstOnnxClient::getOutputNodeNames (void)
140 if (!outputNames.empty() && outputNamesRaw.size() != outputNames.size()) {
141 outputNamesRaw.resize(outputNames.size());
142 for (size_t i = 0; i < outputNamesRaw.size(); i++) {
// Borrowed pointers: they remain valid only while the owning entries in
// outputNames are alive.
143 outputNamesRaw[i] = outputNames[i].get();
// Returned by value, so the caller receives a copy of the pointer vector.
147 return outputNamesRaw;
// Assigns model output slot `index` to the node function `node`, and records
// the reverse mapping (slot -> function) unless the index is
// GST_ML_NODE_INDEX_DISABLED.
// NOTE(review): the assert checks only the upper bound. Assuming `index` is a
// signed gint (getOutputNodeIndex () returns gint; the parameter declaration
// is not visible here), a negative value other than
// GST_ML_NODE_INDEX_DISABLED would pass the assert and index
// outputNodeIndexToFunction out of range. g_assert is also compiled out under
// G_DISABLE_ASSERT, so it is not a runtime guard - confirm callers validate.
150 void GstOnnxClient::setOutputNodeIndex (GstMlOutputNodeFunction node,
153 g_assert (index < GST_ML_OUTPUT_NODE_NUMBER_OF);
154 outputNodeInfo[node].index = index;
155 if (index != GST_ML_NODE_INDEX_DISABLED)
156 outputNodeIndexToFunction[index] = node;
// Returns the model output slot assigned to `node`; this is
// GST_ML_NODE_INDEX_DISABLED for a node that was never assigned (see the
// GstMlOutputNodeInfo default constructor).
159 gint GstOnnxClient::getOutputNodeIndex (GstMlOutputNodeFunction node)
161 return outputNodeInfo[node].index;
// Overrides the tensor element type recorded for `node`. createSession () also
// writes this field from the model's output type info.
164 void GstOnnxClient::setOutputNodeType (GstMlOutputNodeFunction node,
165 ONNXTensorElementDataType type)
167 outputNodeInfo[node].type = type;
// Returns the tensor element type recorded for `node`
// (ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT until something overrides it).
170 ONNXTensorElementDataType
171 GstOnnxClient::getOutputNodeType (GstMlOutputNodeFunction node)
173 return outputNodeInfo[node].type;
// True once createSession () has stored a session pointer.
176 bool GstOnnxClient::hasSession (void)
178 return session != nullptr;
// Creates the ONNX Runtime session for a model and caches the model's input
// geometry plus its output node names and element types.
//   modelFile - path of the model file, handed to Ort::Session
//   optim     - optimization level, translated to GraphOptimizationLevel
//   provider  - execution provider, stored in m_provider
// NOTE(review): the return statements are not visible here.
// NOTE(review): ONNX Runtime C++ calls report failures by throwing
// (Ort::ThrowOnError is used explicitly below) and no try/catch is visible
// here - confirm how callers handle a failed model load.
181 bool GstOnnxClient::createSession (std::string modelFile,
182 GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider)
187 GraphOptimizationLevel onnx_optim;
// Translate the plugin-side enum into the ONNX Runtime one.
189 case GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL:
190 onnx_optim = GraphOptimizationLevel::ORT_DISABLE_ALL;
192 case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_BASIC:
193 onnx_optim = GraphOptimizationLevel::ORT_ENABLE_BASIC;
195 case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED:
196 onnx_optim = GraphOptimizationLevel::ORT_ENABLE_EXTENDED;
198 case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_ALL:
199 onnx_optim = GraphOptimizationLevel::ORT_ENABLE_ALL;
// Presumably the fallback branch (its label is not visible here): extended
// optimizations.
202 onnx_optim = GraphOptimizationLevel::ORT_ENABLE_EXTENDED;
206 Ort::SessionOptions sessionOptions;
208 //sessionOptions.SetIntraOpNumThreads (1);
209 sessionOptions.SetGraphOptimizationLevel (onnx_optim);
210 m_provider = provider;
211 switch (m_provider) {
212 case GST_ONNX_EXECUTION_PROVIDER_CUDA:
213 #ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA
// Attach the CUDA provider (second argument 0, the device id per the ONNX
// Runtime API); ThrowOnError turns a failure status into an exception.
214 Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CUDA
215 (sessionOptions, 0));
// NOTE(review): raw owning pointer. If a session already exists, this
// assignment leaks it unless an earlier guard (not visible here) returns
// first - confirm.
224 session = new Ort::Session (getEnv (), modelFile.c_str (), sessionOptions);
// Only model input 0 is inspected.
225 auto inputTypeInfo = session->GetInputTypeInfo (0);
226 std::vector < int64_t > inputDims =
227 inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
// NOTE(review): inputDims[1..3] are read without any visible check that the
// input tensor has at least four dimensions.
// HWC layout: dimensions 1, 2, 3 are read as height, width, channels.
228 if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
229 height = inputDims[1];
230 width = inputDims[2];
231 channels = inputDims[3];
// Other layout: dimensions 1, 2, 3 are read as channels, height, width.
233 channels = inputDims[1];
234 height = inputDims[2];
235 width = inputDims[3];
// A non-positive width or height marks the model input as not fixed-size;
// parseDimensions () then takes the frame size from the video meta instead.
238 fixedInputImageSize = width > 0 && height > 0;
239 GST_DEBUG ("Number of Output Nodes: %d", (gint) session->GetOutputCount ());
241 Ort::AllocatorWithDefaultOptions allocator;
242 auto input_name = session->GetInputNameAllocated (0, allocator);
243 GST_DEBUG ("Input name: %s", input_name.get());
245 for (size_t i = 0; i < session->GetOutputCount (); ++i) {
246 auto output_name = session->GetOutputNameAllocated (i, allocator);
// NOTE(review): "%lu" with a size_t argument is not portable (size_t is not
// unsigned long on every target, e.g. 64-bit Windows). GLib's
// G_GSIZE_FORMAT is the portable conversion for this argument.
247 GST_DEBUG("Output name %lu:%s", i, output_name.get());
// Ownership of the allocated name moves into outputNames;
// getOutputNodeNames () later exposes the raw pointers.
248 outputNames.push_back (std::move(output_name));
249 auto type_info = session->GetOutputTypeInfo (i);
250 auto tensor_info = type_info.GetTensorTypeAndShapeInfo ();
// Record the element type for whichever node function is currently mapped to
// this output slot; slots beyond GST_ML_OUTPUT_NODE_NUMBER_OF are ignored.
252 if (i < GST_ML_OUTPUT_NODE_NUMBER_OF) {
253 auto function = outputNodeIndexToFunction[i];
254 outputNodeInfo[function].type = tensor_info.GetElementType ();
// Runs inference on one frame and returns the resulting bounding boxes.
// The element type recorded for the CLASS output node selects the doRun
// instantiation: float when it is ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, int
// for anything else.
// NOTE(review): every non-float class tensor is therefore read as `int`;
// confirm that no supported model emits another type (e.g. int64).
261 std::vector < GstMlBoundingBox > GstOnnxClient::run (uint8_t * img_data,
262 GstVideoMeta * vmeta, std::string labelPath, float scoreThreshold)
264 auto type = getOutputNodeType (GST_ML_OUTPUT_NODE_FUNCTION_CLASS);
266 ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) ?
267 doRun < float >(img_data, vmeta, labelPath, scoreThreshold)
268 : doRun < int >(img_data, vmeta, labelPath, scoreThreshold);
// Works out the frame size to use for the next inference and (re)allocates
// the `dest` staging buffer when required.
// With a fixed-size model the cached width/height are kept; otherwise the
// size is taken from the frame's video meta.
271 void GstOnnxClient::parseDimensions (GstVideoMeta * vmeta)
273 int32_t newWidth = fixedInputImageSize ? width : vmeta->width;
274 int32_t newHeight = fixedInputImageSize ? height : vmeta->height;
// Allocate when there is no buffer yet, or when the new frame has more pixels
// than the current width * height.
// NOTE(review): `dest` is a raw new[] buffer; the release of a previous
// buffer and the update of width/height are not visible here - confirm.
// NOTE(review): the size products are evaluated in 32-bit arithmetic
// (assuming width/height/channels are int32_t, as getWidth () suggests) and
// could overflow for very large frames.
276 if (!dest || width * height < newWidth * newHeight) {
278 dest = new uint8_t[newWidth * newHeight * channels];
// Converts one packed video frame into the model's input layout, runs the
// session and turns the model outputs into GstMlBoundingBox entries.
//   T              - element type of the class-label output tensor
//   img_data       - start of the frame's pixel data
//   vmeta          - video meta supplying format, stride[0] and frame size
//   labelPath      - label file, loaded lazily on first use when non-empty
//   scoreThreshold - detections whose score is not above this are dropped
284 template < typename T > std::vector < GstMlBoundingBox >
285 GstOnnxClient::doRun (uint8_t * img_data, GstVideoMeta * vmeta,
286 std::string labelPath, float scoreThreshold)
288 std::vector < GstMlBoundingBox > boundingBoxes;
// Presumably an early exit (its condition is not visible here) that yields
// an empty result.
290 return boundingBoxes;
292 parseDimensions (vmeta);
294 Ort::AllocatorWithDefaultOptions allocator;
295 auto inputName = session->GetInputNameAllocated (0, allocator);
296 auto inputTypeInfo = session->GetInputTypeInfo (0);
297 std::vector < int64_t > inputDims =
298 inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
// Overwrite the spatial dimensions of the model's declared shape with the
// size chosen by parseDimensions (); which indices are patched depends on
// the layout.
300 if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
301 inputDims[1] = height;
302 inputDims[2] = width;
304 inputDims[2] = height;
305 inputDims[3] = width;
308 std::ostringstream buffer;
310 GST_DEBUG ("Input dimensions: %s", buffer.str ().c_str ());
// srcPtr holds one read cursor per output channel. The default offsets
// 0, 1, 2 match packed RGB; the cases below re-point the cursors so that
// srcPtr[0..2] always address R, G, B for the given pixel format.
313 uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 };
314 uint32_t srcSamplesPerPixel = 3;
315 switch (vmeta->format) {
316 case GST_VIDEO_FORMAT_RGBA:
317 srcSamplesPerPixel = 4;
319 case GST_VIDEO_FORMAT_BGRA:
320 srcSamplesPerPixel = 4;
321 srcPtr[0] = img_data + 2;
322 srcPtr[1] = img_data + 1;
323 srcPtr[2] = img_data + 0;
325 case GST_VIDEO_FORMAT_ARGB:
326 srcSamplesPerPixel = 4;
327 srcPtr[0] = img_data + 1;
328 srcPtr[1] = img_data + 2;
329 srcPtr[2] = img_data + 3;
331 case GST_VIDEO_FORMAT_ABGR:
332 srcSamplesPerPixel = 4;
333 srcPtr[0] = img_data + 3;
334 srcPtr[1] = img_data + 2;
335 srcPtr[2] = img_data + 1;
337 case GST_VIDEO_FORMAT_BGR:
338 srcPtr[0] = img_data + 2;
339 srcPtr[1] = img_data + 1;
340 srcPtr[2] = img_data + 0;
345 size_t destIndex = 0;
// Only the stride of plane 0 is used.
346 uint32_t stride = vmeta->stride[0];
// HWC: write the channel samples of each pixel next to each other in dest.
// NOTE(review): srcPtr has exactly 3 entries but the inner loops run to
// `channels`; a model reporting more than 3 channels would index srcPtr out
// of range, and the stride fix-up always advances 3 cursors.
347 if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
348 for (int32_t j = 0; j < height; ++j) {
349 for (int32_t i = 0; i < width; ++i) {
350 for (int32_t k = 0; k < channels; ++k) {
351 dest[destIndex++] = *srcPtr[k];
352 srcPtr[k] += srcSamplesPerPixel;
355 // correct for stride
356 for (uint32_t k = 0; k < 3; ++k)
357 srcPtr[k] += stride - srcSamplesPerPixel * width;
// Other layout: one width * height plane per channel, laid out back to back
// in dest.
// NOTE(review): the statement advancing destIndex in this branch is not
// visible here - confirm it is incremented once per pixel.
360 size_t frameSize = width * height;
361 uint8_t *destPtr[3] = { dest, dest + frameSize, dest + 2 * frameSize };
362 for (int32_t j = 0; j < height; ++j) {
363 for (int32_t i = 0; i < width; ++i) {
364 for (int32_t k = 0; k < channels; ++k) {
365 destPtr[k][destIndex] = *srcPtr[k];
366 srcPtr[k] += srcSamplesPerPixel;
370 // correct for stride
371 for (uint32_t k = 0; k < 3; ++k)
372 srcPtr[k] += stride - srcSamplesPerPixel * width;
376 const size_t inputTensorSize = width * height * channels;
378 Ort::MemoryInfo::CreateCpu (OrtAllocatorType::OrtArenaAllocator,
379 OrtMemType::OrtMemTypeDefault);
380 std::vector < Ort::Value > inputTensors;
// The input tensor is built over the `dest` buffer with uint8_t elements and
// the patched inputDims.
381 inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo,
382 dest, inputTensorSize, inputDims.data (), inputDims.size ()));
383 std::vector < const char *>inputNames { inputName.get () };
// NOTE(review): outputNamesRaw is used directly here. In the visible code it
// is only populated by getOutputNodeNames (), so if that was never called
// after createSession () the session is asked for zero outputs and the
// modelOutput lookups below would be out of range - confirm.
385 std::vector < Ort::Value > modelOutput = session->Run (Ort::RunOptions { nullptr},
387 inputTensors.data (), 1, outputNamesRaw.data (), outputNamesRaw.size ());
// The three lookups below read the detection, bounding-box and score tensors
// as float (presumably into numDetections, bboxes and scores; the left-hand
// sides are not visible here).
// NOTE(review): getOutputNodeIndex () can return GST_ML_NODE_INDEX_DISABLED;
// only the CLASS node is checked for that before indexing modelOutput.
390 modelOutput[getOutputNodeIndex
391 (GST_ML_OUTPUT_NODE_FUNCTION_DETECTION)].GetTensorMutableData < float >();
393 modelOutput[getOutputNodeIndex
394 (GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX)].GetTensorMutableData < float >();
396 modelOutput[getOutputNodeIndex
397 (GST_ML_OUTPUT_NODE_FUNCTION_SCORE)].GetTensorMutableData < float >();
// The class output is optional: labelIndex stays null when the node is
// disabled.
398 T *labelIndex = nullptr;
399 if (getOutputNodeIndex (GST_ML_OUTPUT_NODE_FUNCTION_CLASS) !=
400 GST_ML_NODE_INDEX_DISABLED) {
402 modelOutput[getOutputNodeIndex
403 (GST_ML_OUTPUT_NODE_FUNCTION_CLASS)].GetTensorMutableData < T > ();
// Labels are loaded once and reused on later calls; an empty labelPath
// disables labelling.
405 if (labels.empty () && !labelPath.empty ())
406 labels = ReadLabels (labelPath);
// The first element of the detection tensor is used as the number of
// detections to scan.
408 for (int i = 0; i < numDetections[0]; ++i) {
409 if (scores[i] > scoreThreshold) {
410 std::string label = "";
// NOTE(review): the model's class value is used as a 1-based index into
// `labels` with no range check. A value of 0, or one larger than the number
// of lines in the label file, reads outside the vector - add a bounds check.
412 if (labelIndex && !labels.empty ())
413 label = labels[labelIndex[i] - 1];
414 auto score = scores[i];
// Each box is four floats read as y0, x0, y1, x1 and scaled by the frame
// height/width (so presumably normalised coordinates - confirm); the second
// corner is converted into a height/width relative to the first.
415 auto y0 = bboxes[i * 4] * height;
416 auto x0 = bboxes[i * 4 + 1] * width;
417 auto bheight = bboxes[i * 4 + 2] * height - y0;
418 auto bwidth = bboxes[i * 4 + 3] * width - x0;
419 boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, bwidth,
423 return boundingBoxes;
// Reads a label file into a vector, one entry per line, in file order.
// NOTE(review): no check that the file could be opened is visible; with an
// unreadable path std::getline fails immediately and the vector stays empty.
// NOTE(review): the declaration of `line` and the return statement are not
// visible here.
426 std::vector < std::string >
427 GstOnnxClient::ReadLabels (const std::string & labelsFile)
429 std::vector < std::string > labels;
431 std::ifstream fp (labelsFile);
432 while (std::getline (fp, line))
433 labels.push_back (line);