Add support for TensorRT 10.
authorBram Veldhoen <bram@adiplus.nl>
Sat, 15 Jun 2024 15:30:35 +0000 (17:30 +0200)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Tue, 25 Jun 2024 00:38:21 +0000 (09:38 +0900)
Adds a tensor_filter (tensorrt10) for TensorRT 10+ using the onnx and
engine parsers. Leaves as-is the tensor_filter (tensorrt) for TensorRT
10- using the uffparsers.

Signed-off-by: Bram Veldhoen <noreply@github.com>
ext/nnstreamer/tensor_filter/meson.build
ext/nnstreamer/tensor_filter/tensor_filter_tensorrt10.cc [new file with mode: 0644]
meson.build
meson_options.txt
tests/nnstreamer_filter_tensorrt10/runTest.sh [new file with mode: 0755]
tests/test_models/labels/coco.txt [new file with mode: 0644]
tools/cuda-version.sh [new file with mode: 0755]
tools/tensorrt-version.sh [new file with mode: 0755]

index c8e0028..c87acc7 100644 (file)
@@ -683,6 +683,26 @@ if tensorrt_support_is_available
   )
 endif
 
+if tensorrt10_support_is_available
+  filter_sub_tensorrt10_sources = ['tensor_filter_tensorrt10.cc']
+
+  nnstreamer_filter_tensorrt10_deps = [glib_dep, nnstreamer_single_dep, tensorrt10_support_deps]
+
+  shared_library('nnstreamer_filter_tensorrt10',
+    filter_sub_tensorrt10_sources,
+    dependencies: nnstreamer_filter_tensorrt10_deps,
+    install: true,
+    install_dir: filter_subplugin_install_dir
+  )
+
+  static_library('nnstreamer_filter_tensorrt10',
+    filter_sub_tensorrt10_sources,
+    dependencies: nnstreamer_filter_tensorrt10_deps,
+    install: true,
+    install_dir: nnstreamer_libdir
+  )
+endif
+
 if lua_support_is_available
   if lua_support_deps[0].version().version_compare('>=5.3')
     message ('tensor-filter::lua does not support Lua >= 5.3, yet. Fix #3531 first.')
diff --git a/ext/nnstreamer/tensor_filter/tensor_filter_tensorrt10.cc b/ext/nnstreamer/tensor_filter/tensor_filter_tensorrt10.cc
new file mode 100644 (file)
index 0000000..8bdb130
--- /dev/null
@@ -0,0 +1,775 @@
+/* SPDX-License-Identifier: LGPL-2.1-only */
+/**
+ * GStreamer Tensor_Filter, TensorRT Module
+ * Copyright (C) 2024 Bram Veldhoen
+ */
+/**
+ * @file   tensor_filter_tensorrt10.cc
+ * @date   Jun 2024
+ * @brief  TensorRT 10+ module for tensor_filter gstreamer plugin
+ * @see    http://github.com/nnstreamer/nnstreamer
+ * @see    https://github.com/NVIDIA/TensorRT
+ * @author Bram Veldhoen
+ * @bug    No known bugs except for NYI items
+ *
+ * This is the per-NN-framework plugin (TensorRT) for tensor_filter.
+ *
+ * @note Supports onnxruntime .onnx and tensorrt .engine file as inference model formats.
+ *   When an .onnx file is provided, this plugin will generate the tensorrt .engine file,
+ *   and store it in /tmp/<modelfilename>.engine.
+ *
+ * @todo:
+ *  - Add option parameter for device_id.
+ *  - Add option parameter for generated .engine file (now default in /tmp).
+ *  - Add support for model builder parameters.
+ *  - Add support for optimization profile(s).
+ *  - Add support for batch_size > 1 (to allow for i.e. multiple camera streams).
+ */
+
+#include <algorithm>
+#include <filesystem>
+#include <fstream>
+#include <memory>
+#include <stdexcept>
+#include <vector>
+
+#include <nnstreamer_cppplugin_api_filter.hh>
+#include <nnstreamer_log.h>
+#include <nnstreamer_plugin_api_util.h>
+#include <nnstreamer_util.h>
+
+#include <NvInfer.h>
+#include <NvOnnxParser.h>
+#include <cuda_runtime_api.h>
+
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+
+using Severity = nvinfer1::ILogger::Severity;
+
+/** @brief a global object of ILogger */
+class Logger : public nvinfer1::ILogger
+{
+  void log (Severity severity, const char *msg) noexcept override
+  {
+    switch (severity) {
+      case Severity::kWARNING:
+        ml_logw ("NVINFER: %s", msg);
+        break;
+      case Severity::kINFO:
+        ml_logi ("NVINFER: %s", msg);
+        break;
+      case Severity::kVERBOSE:
+        ml_logd ("NVINFER: %s", msg);
+        break;
+      default:
+        ml_loge ("NVINFER: %s", msg);
+        break;
+    }
+  }
+} gLogger;
+
+using nnstreamer::tensor_filter_subplugin;
+
+namespace nnstreamer
+{
+namespace tensorfilter_tensorrt10
+{
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+void _init_filter_tensorrt10 (void) __attribute__ ((constructor));
+void _fini_filter_tensorrt10 (void) __attribute__ ((destructor));
+#ifdef __cplusplus
+}
+#endif /* __cplusplus */
+
+/** @brief Deleter for instances from the TensorRT library. */
+struct NvInferDeleter {
+  template <typename T> void operator() (T *obj) const
+  {
+    delete obj;
+  }
+};
+
+template <typename T>
+using NvInferUniquePtr = std::unique_ptr<T, NvInferDeleter>;
+
+template <typename T>
+NvInferUniquePtr<T>
+makeNvInferUniquePtr (T *t)
+{
+  return NvInferUniquePtr<T>{ t };
+}
+
+/** @brief Holds metadata related to a TensorRT tensor. */
+struct NvInferTensorInfo {
+  const char *name;
+  nvinfer1::TensorIOMode mode;
+  nvinfer1::Dims shape;
+  nvinfer1::DataType dtype;
+  std::size_t dtype_size;
+  std::size_t volume;
+  std::size_t buffer_size;
+  void *buffer; /**< Cuda buffer */
+};
+
+/** @brief tensorrt10 subplugin class */
+class tensorrt10_subplugin final : public tensor_filter_subplugin
+{
+  template <typename T> using UniquePtr = std::unique_ptr<T>;
+
+  public:
+  static void init_filter_tensorrt10 ();
+  static void fini_filter_tensorrt10 ();
+
+  tensorrt10_subplugin ();
+  ~tensorrt10_subplugin ();
+
+  tensor_filter_subplugin &getEmptyInstance ();
+  void configure_instance (const GstTensorFilterProperties *prop);
+  void invoke (const GstTensorMemory *input, GstTensorMemory *output);
+  void getFrameworkInfo (GstTensorFilterFrameworkInfo &info);
+  int getModelInfo (model_info_ops ops, GstTensorsInfo &in_info, GstTensorsInfo &out_info);
+  int eventHandler (event_ops ops, GstTensorFilterFrameworkEventData &data);
+
+
+  private:
+  static const char *name;
+  static const accl_hw hw_list[];
+  static const int num_hw = 0;
+  static tensorrt10_subplugin *registeredRepresentation;
+
+  bool _configured{}; /**< Flag to keep track of whether this instance has been configured or not. */
+  int _device_id{}; /**< Device id of the gpu to use. */
+  gchar *_model_path{}; /**< engine file path */
+  std::filesystem::path _engine_fs_path{}; /**< filesystem path to engine file */
+  cudaStream_t _stream{}; /**< The cuda inference stream */
+
+  GstTensorsInfo _inputTensorMeta;
+  GstTensorsInfo _outputTensorMeta;
+
+  NvInferUniquePtr<nvinfer1::IRuntime> _Runtime{};
+  NvInferUniquePtr<nvinfer1::ICudaEngine> _Engine{};
+  NvInferUniquePtr<nvinfer1::IExecutionContext> _Context{};
+
+  std::vector<NvInferTensorInfo> _tensorrt10_input_tensor_infos{};
+  std::vector<NvInferTensorInfo> _tensorrt10_output_tensor_infos{};
+
+  void cleanup ();
+  void allocBuffer (void **buffer, gsize size);
+  void loadModel (const GstTensorFilterProperties *prop);
+  void checkUnifiedMemory () const;
+  void convertTensorsInfo (const std::vector<NvInferTensorInfo> &tensorrt10_tensor_infos,
+      GstTensorsInfo &info) const;
+  std::size_t getVolume (const nvinfer1::Dims &shape) const;
+  std::size_t getTensorRTDataTypeSize (nvinfer1::DataType tensorrt10_data_type) const;
+  tensor_type getNnStreamerDataType (nvinfer1::DataType tensorrt10_data_type) const;
+
+  void constructNetwork (NvInferUniquePtr<nvonnxparser::IParser> &parser) const;
+  void buildSaveEngine () const;
+  void loadEngine ();
+};
+
+const char *tensorrt10_subplugin::name = "tensorrt10";
+const accl_hw tensorrt10_subplugin::hw_list[] = {};
+
+/**
+ * @brief constructor of tensorrt10_subplugin
+ */
+tensorrt10_subplugin::tensorrt10_subplugin () : tensor_filter_subplugin ()
+{
+}
+
+/**
+ * @brief destructor of tensorrt10_subplugin
+ */
+tensorrt10_subplugin::~tensorrt10_subplugin ()
+{
+  cleanup ();
+}
+
+void
+tensorrt10_subplugin::cleanup ()
+{
+  if (!_configured) {
+    return;
+  }
+
+  gst_tensors_info_free (&_inputTensorMeta);
+  gst_tensors_info_free (&_outputTensorMeta);
+
+  for (auto &tensorrt10_tensor_info : _tensorrt10_input_tensor_infos) {
+    cudaFree (tensorrt10_tensor_info.buffer);
+    tensorrt10_tensor_info.buffer = nullptr;
+  }
+
+  if (_model_path != nullptr) {
+    g_free (_model_path);
+    _model_path = nullptr;
+  }
+
+  if (_stream) {
+    cudaStreamDestroy (_stream);
+  }
+}
+
+/**
+ * @brief Returns an empty instance.
+ * @return an empty instance
+ */
+tensor_filter_subplugin &
+tensorrt10_subplugin::getEmptyInstance ()
+{
+  return *(new tensorrt10_subplugin ());
+}
+
+/**
+ * @brief Configure the instance of the tensorrt10_subplugin.
+ * @param[in] prop property of tensor_filter instance
+ */
+void
+tensorrt10_subplugin::configure_instance (const GstTensorFilterProperties *prop)
+{
+  /* Set model path */
+  if (prop->num_models != 1 || !prop->model_files[0]) {
+    ml_loge ("TensorRT filter requires one engine model file.");
+    throw std::invalid_argument ("The .engine model file is not given.");
+  }
+  assert (_model_path == nullptr);
+  _model_path = g_strdup (prop->model_files[0]);
+
+  /* Make a TensorRT engine */
+  loadModel (prop);
+
+  _configured = true;
+}
+
+/**
+ * @brief Invoke the TensorRT model and get the inference result.
+ * @param[in] input The array of input tensors
+ * @param[out] output The array of output tensors
+ */
+void
+tensorrt10_subplugin::invoke (const GstTensorMemory *input, GstTensorMemory *output)
+{
+  ml_logi ("tensorrt10_subplugin::invoke");
+  g_assert (_configured);
+
+  if (!input)
+    throw std::runtime_error ("Invalid input buffer, it is NULL.");
+  if (!output)
+    throw std::runtime_error ("Invalid output buffer, it is NULL.");
+
+  cudaError_t status;
+
+  /* Copy input data to Cuda memory space */
+  for (std::size_t i = 0; i < _tensorrt10_input_tensor_infos.size (); ++i) {
+    const auto &tensorrt10_tensor_info = _tensorrt10_input_tensor_infos[i];
+    g_assert (tensorrt10_tensor_info.buffer_size == input[i].size);
+
+    status = cudaMemcpyAsync (tensorrt10_tensor_info.buffer, input[i].data,
+        input[i].size, cudaMemcpyHostToDevice, _stream);
+
+    if (status != cudaSuccess) {
+      ml_loge ("Failed to copy to cuda input buffer");
+      throw std::runtime_error ("Failed to copy to cuda input buffer");
+    }
+  }
+
+  for (std::size_t i = 0; i < _tensorrt10_output_tensor_infos.size (); ++i) {
+    const auto &tensorrt10_tensor_info = _tensorrt10_output_tensor_infos[i];
+    g_assert (tensorrt10_tensor_info.buffer_size == output[i].size);
+    allocBuffer (&output[i].data, output[i].size);
+    if (!_Context->setOutputTensorAddress (
+            tensorrt10_tensor_info.name, output[i].data)) {
+      ml_loge ("Unable to set output tensor address");
+      throw std::runtime_error ("Unable to set output tensor address");
+    }
+  }
+
+  /* Execute the network */
+  if (!_Context->enqueueV3 (_stream)) {
+    ml_loge ("Failed to execute the network");
+    throw std::runtime_error ("Failed to execute the network");
+  }
+
+  /* Wait for GPU to finish the inference */
+  status = cudaStreamSynchronize (_stream);
+
+  if (status != cudaSuccess) {
+    ml_loge ("Failed to synchronize the cuda stream");
+    throw std::runtime_error ("Failed to synchronize the cuda stream");
+  }
+}
+
+/**
+ * @brief Describe the subplugin's setting.
+ */
+void
+tensorrt10_subplugin::getFrameworkInfo (GstTensorFilterFrameworkInfo &info)
+{
+  info.name = name;
+  info.allow_in_place = FALSE;
+  info.allocate_in_invoke = TRUE;
+  info.run_without_model = FALSE;
+  info.verify_model_path = TRUE;
+  info.hw_list = hw_list;
+  info.num_hw = num_hw;
+}
+
+/**
+ * @brief Get the in/output tensors info.
+ */
+int
+tensorrt10_subplugin::getModelInfo (
+    model_info_ops ops, GstTensorsInfo &in_info, GstTensorsInfo &out_info)
+{
+  if (ops != GET_IN_OUT_INFO) {
+    return -ENOENT;
+  }
+
+  gst_tensors_info_copy (std::addressof (in_info), std::addressof (_inputTensorMeta));
+  gst_tensors_info_copy (std::addressof (out_info), std::addressof (_outputTensorMeta));
+
+  return 0;
+}
+
+/**
+ * @brief Override eventHandler to free Cuda data buffer.
+ */
+int
+tensorrt10_subplugin::eventHandler (event_ops ops, GstTensorFilterFrameworkEventData &data)
+{
+  if (ops == DESTROY_NOTIFY) {
+    if (data.data != nullptr) {
+      cudaFree (data.data);
+    }
+  }
+  return 0;
+}
+
+/**
+ * @brief Parses the model; changes the state of the provided parser.
+ */
+void
+tensorrt10_subplugin::constructNetwork (NvInferUniquePtr<nvonnxparser::IParser> &parser) const
+{
+  auto parsed = parser->parseFromFile (
+      _model_path, static_cast<int> (nvinfer1::ILogger::Severity::kWARNING));
+  if (!parsed) {
+    ml_loge ("Unable to parse onnx file");
+    throw std::runtime_error ("Unable to parse onnx file");
+  }
+}
+
+/**
+ * Builds and saves the .engine file.
+ */
+void
+tensorrt10_subplugin::buildSaveEngine () const
+{
+  auto builder = makeNvInferUniquePtr (nvinfer1::createInferBuilder (gLogger));
+  if (!builder) {
+    ml_loge ("Unable to create builder");
+    throw std::runtime_error ("Unable to create builder");
+  }
+
+  auto network = makeNvInferUniquePtr (builder->createNetworkV2 (0));
+  if (!network) {
+    ml_loge ("Unable to create network");
+    throw std::runtime_error ("Unable to create network");
+  }
+
+  auto config = makeNvInferUniquePtr (builder->createBuilderConfig ());
+  if (!config) {
+    ml_loge ("Unable to create builder config");
+    throw std::runtime_error ("Unable to create builder config");
+  }
+
+  auto parser = makeNvInferUniquePtr (nvonnxparser::createParser (*network, gLogger));
+  if (!parser) {
+    ml_loge ("Unable to create onnx parser");
+    throw std::runtime_error ("Unable to create onnx parser");
+  }
+
+  constructNetwork (parser);
+
+  auto host_memory
+      = makeNvInferUniquePtr (builder->buildSerializedNetwork (*network, *config));
+  if (!host_memory) {
+    ml_loge ("Unable to build serialized network");
+    throw std::runtime_error ("Unable to build serialized network");
+  }
+
+  std::ofstream engineFile (_engine_fs_path, std::ios::binary);
+  if (!engineFile) {
+    ml_loge ("Unable to open engine file for saving");
+    throw std::runtime_error ("Unable to open engine file for saving");
+  }
+  engineFile.write (static_cast<char *> (host_memory->data ()), host_memory->size ());
+}
+
+/**
+ * @brief Loads the .engine model file and makes a member object to be used for inference.
+ */
+void
+tensorrt10_subplugin::loadEngine ()
+{
+  // Create file
+  std::ifstream file (_engine_fs_path, std::ios::binary | std::ios::ate);
+  std::streamsize size = file.tellg ();
+  if (size < 0) {
+    ml_loge ("Unable to open engine file %s", std::string (_engine_fs_path).data ());
+    throw std::runtime_error ("Unable to open engine file");
+  }
+
+  file.seekg (0, std::ios::beg);
+  ml_logi ("Loading tensorrt10 engine from file: %s with buffer size: %" G_GUINT64_FORMAT,
+      std::string (_engine_fs_path).data (), size);
+
+  // Read file
+  std::vector<char> tensorrt10_engine_file_buffer (size);
+  if (!file.read (tensorrt10_engine_file_buffer.data (), size)) {
+    ml_loge ("Unable to read engine file %s", std::string (_engine_fs_path).data ());
+    throw std::runtime_error ("Unable to read engine file");
+  }
+
+  // Create an engine, a representation of the optimized model.
+  _Engine = NvInferUniquePtr<nvinfer1::ICudaEngine> (_Runtime->deserializeCudaEngine (
+      tensorrt10_engine_file_buffer.data (), tensorrt10_engine_file_buffer.size ()));
+  if (!_Engine) {
+    ml_loge ("Unable to deserialize tensorrt10 engine");
+    throw std::runtime_error ("Unable to deserialize tensorrt10 engine");
+  }
+}
+
+
+/**
+ * @brief Loads and interprets the model file.
+ * @param[in] prop: property of tensor_filter instance
+ */
+void
+tensorrt10_subplugin::loadModel (const GstTensorFilterProperties *prop)
+{
+  // GstTensorInfo *_info;
+
+  UNUSED (prop);
+
+  // Set the device index
+  auto ret = cudaSetDevice (_device_id);
+  if (ret != 0) {
+    int num_gpus;
+    cudaGetDeviceCount (&num_gpus);
+    ml_loge ("Unable to set GPU device index to: %d. CUDA-capable GPU(s): %d.",
+        _device_id, num_gpus);
+    throw std::runtime_error ("Unable to set GPU device index");
+  }
+
+  checkUnifiedMemory ();
+
+  // Parse model from .onnx and create .engine if necessary
+  std::filesystem::path model_fs_path (_model_path);
+  if (".onnx" == model_fs_path.extension ()) {
+    _engine_fs_path = std::filesystem::path ("/tmp") / model_fs_path.stem ();
+    _engine_fs_path += ".engine";
+    if (!std::filesystem::exists (_engine_fs_path)) {
+      buildSaveEngine ();
+      g_assert (std::filesystem::exists (_engine_fs_path));
+    }
+  } else if (".engine" == model_fs_path.extension ()) {
+    _engine_fs_path = model_fs_path;
+  } else {
+    ml_loge ("Unsupported model file extension %s",
+        std::string (model_fs_path.extension ()).data ());
+    throw std::runtime_error ("Unsupported model file extension");
+  }
+
+  // Create a runtime to deserialize the engine file.
+  _Runtime = makeNvInferUniquePtr (nvinfer1::createInferRuntime (gLogger));
+  if (!_Runtime) {
+    ml_loge ("Failed to create TensorRT runtime");
+    throw std::runtime_error ("Failed to create TensorRT runtime");
+  }
+
+  loadEngine ();
+
+  /* Create ExecutionContext object */
+  _Context = makeNvInferUniquePtr (_Engine->createExecutionContext ());
+  if (!_Context) {
+    ml_loge ("Failed to create the TensorRT ExecutionContext object");
+    throw std::runtime_error ("Failed to create the TensorRT ExecutionContext object");
+  }
+
+  // Create the cuda stream
+  cudaStreamCreate (&_stream);
+
+  // Get number of IO buffers
+  auto num_io_buffers = _Engine->getNbIOTensors ();
+  if (num_io_buffers <= 0) {
+    ml_loge ("Engine has no IO buffers");
+    throw std::runtime_error ("Engine has no IO buffers");
+  }
+
+  // Iterate the model io buffers
+  _tensorrt10_input_tensor_infos.clear ();
+  _tensorrt10_output_tensor_infos.clear ();
+  for (int buffer_index = 0; buffer_index < num_io_buffers; ++buffer_index) {
+    NvInferTensorInfo tensorrt10_tensor_info{};
+
+    // Get buffer name
+    tensorrt10_tensor_info.name = _Engine->getIOTensorName (buffer_index);
+
+    // Read and verify IO buffer shape
+    tensorrt10_tensor_info.shape
+        = _Engine->getTensorShape (tensorrt10_tensor_info.name);
+    if (tensorrt10_tensor_info.shape.d[0] == -1) {
+      ml_loge ("Dynamic batch size is not supported");
+      throw std::runtime_error ("Dynamic batch size is not supported");
+    }
+
+    // Get data type and buffer size info
+    tensorrt10_tensor_info.mode
+        = _Engine->getTensorIOMode (tensorrt10_tensor_info.name);
+    tensorrt10_tensor_info.dtype
+        = _Engine->getTensorDataType (tensorrt10_tensor_info.name);
+    tensorrt10_tensor_info.dtype_size
+        = getTensorRTDataTypeSize (tensorrt10_tensor_info.dtype);
+    tensorrt10_tensor_info.volume = getVolume (tensorrt10_tensor_info.shape);
+    tensorrt10_tensor_info.buffer_size
+        = tensorrt10_tensor_info.dtype_size * tensorrt10_tensor_info.volume;
+    ml_logd ("BUFFER SIZE: %" G_GUINT64_FORMAT, tensorrt10_tensor_info.buffer_size);
+
+    // Iterate the input and output buffers
+    if (tensorrt10_tensor_info.mode == nvinfer1::TensorIOMode::kINPUT) {
+
+      if (!_Context->setInputShape (
+              tensorrt10_tensor_info.name, tensorrt10_tensor_info.shape)) {
+        ml_loge ("Unable to set input shape");
+        throw std::runtime_error ("Unable to set input shape");
+      }
+
+      // Allocate only for input, memory for output is allocated in the invoke method.
+      allocBuffer (&tensorrt10_tensor_info.buffer, tensorrt10_tensor_info.buffer_size);
+      if (!_Context->setInputTensorAddress (
+              tensorrt10_tensor_info.name, tensorrt10_tensor_info.buffer)) {
+        ml_loge ("Unable to set input tensor address");
+        throw std::runtime_error ("Unable to set input tensor address");
+      }
+
+      _tensorrt10_input_tensor_infos.push_back (tensorrt10_tensor_info);
+
+    } else if (tensorrt10_tensor_info.mode == nvinfer1::TensorIOMode::kOUTPUT) {
+
+      _tensorrt10_output_tensor_infos.push_back (tensorrt10_tensor_info);
+
+    } else {
+      ml_loge ("TensorIOMode not supported");
+      throw std::runtime_error ("TensorIOMode not supported");
+    }
+  }
+
+  if (!_Context->allInputDimensionsSpecified ()) {
+    ml_loge ("Not all required dimensions were specified");
+    throw std::runtime_error ("Not all required dimensions were specified");
+  }
+
+  convertTensorsInfo (_tensorrt10_input_tensor_infos, _inputTensorMeta);
+  convertTensorsInfo (_tensorrt10_output_tensor_infos, _outputTensorMeta);
+}
+
+/**
+ * Converts the NvInferTensorInfo's to the nnstreamer GstTensorsInfo.
+ */
+void
+tensorrt10_subplugin::convertTensorsInfo (
+    const std::vector<NvInferTensorInfo> &tensorrt10_tensor_infos, GstTensorsInfo &info) const
+{
+  gst_tensors_info_init (std::addressof (info));
+  info.num_tensors = tensorrt10_tensor_infos.size ();
+
+  for (guint tensor_index = 0; tensor_index < info.num_tensors; ++tensor_index) {
+    const auto &tensorrt10_tensor_info = tensorrt10_tensor_infos[tensor_index];
+
+    // Set the nnstreamer GstTensorInfo properties
+    GstTensorInfo *tensor_info
+        = gst_tensors_info_get_nth_info (std::addressof (info), tensor_index);
+    tensor_info->name = g_strdup (tensorrt10_tensor_info.name);
+    tensor_info->type = getNnStreamerDataType (tensorrt10_tensor_info.dtype);
+
+    // Set tensor dimensions in reverse order
+    for (int dim_index = 0; dim_index < tensorrt10_tensor_info.shape.nbDims; ++dim_index) {
+      std::size_t from_dim_index = tensorrt10_tensor_info.shape.nbDims - dim_index - 1;
+      tensor_info->dimension[dim_index] = tensorrt10_tensor_info.shape.d[from_dim_index];
+    }
+  }
+}
+
+/**
+ * @brief Return whether Unified Memory is supported or not.
+ * @note After Cuda version 6, logical Unified Memory is supported in
+ * programming language level. However, if the target device is not supported,
+ * then cudaMemcpy() internally occurs and it makes performance degradation.
+ */
+void
+tensorrt10_subplugin::checkUnifiedMemory () const
+{
+  int version;
+
+  if (cudaRuntimeGetVersion (&version) != cudaSuccess) {
+    ml_loge ("Unable to get cuda runtime version");
+    throw std::runtime_error ("Unable to get cuda runtime version");
+  }
+
+  /* Unified memory requires at least CUDA-6 */
+  if (version < 6000) {
+    ml_loge ("Unified memory requires at least CUDA-6");
+    throw std::runtime_error ("Unified memory requires at least CUDA-6");
+  }
+
+  // Get device properties
+  cudaDeviceProp prop;
+  cudaGetDeviceProperties (&prop, _device_id);
+  if (prop.managedMemory == 0) {
+    ml_loge ("The current device does not support managedmemory");
+    throw std::runtime_error ("The current device does not support managedmemory");
+  }
+
+  // The cuda programming guide specifies at least compute capability version 5
+  //  https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#unified-memory-programming
+  if (prop.major < 5) {
+    ml_loge ("The minimum required compute capability for unified memory is version 5");
+    throw std::runtime_error (
+        "The minimum required compute capability for unified memory is version 5");
+  }
+}
+
+/**
+ * @brief Allocates a GPU buffer memory
+ * @param[out] buffer : pointer to allocated memory
+ * @param[in] size : allocation size in bytes
+ */
+void
+tensorrt10_subplugin::allocBuffer (void **buffer, gsize size)
+{
+  cudaError_t status = cudaMallocManaged (buffer, size);
+
+  if (status != cudaSuccess) {
+    ml_loge ("Failed to allocate Cuda memory");
+    throw std::runtime_error ("Failed to allocate Cuda memory");
+  }
+}
+
+/**
+ * @brief Calculates the volume (in elements, not in bytes) of the provided shape.
+ * @param[in] shape : The shape for which to calculate the volume.
+ */
+std::size_t
+tensorrt10_subplugin::getVolume (const nvinfer1::Dims &shape) const
+{
+  auto volume = 1;
+  for (auto i = 0; i < shape.nbDims; ++i) {
+    volume *= shape.d[i];
+  }
+  return volume;
+}
+
+/**
+ * @brief Get the size of the TensorRT data type.
+ * @param[in] tensorrt10_data_type : The TensorRT data type.
+ * @note: see also https://github.com/NVIDIA/TensorRT/blob/ccf119972b50299ba00d35d39f3938296e187f4e/samples/common/common.h#L539C1-L552C14
+ */
+std::size_t
+tensorrt10_subplugin::getTensorRTDataTypeSize (nvinfer1::DataType tensorrt10_data_type) const
+{
+  switch (tensorrt10_data_type) {
+    case nvinfer1::DataType::kINT64:
+      return 8;
+    case nvinfer1::DataType::kINT32:
+    case nvinfer1::DataType::kFLOAT:
+      return 4;
+    case nvinfer1::DataType::kBF16:
+    case nvinfer1::DataType::kHALF:
+      return 2;
+    case nvinfer1::DataType::kBOOL:
+    case nvinfer1::DataType::kUINT8:
+    case nvinfer1::DataType::kINT8:
+    case nvinfer1::DataType::kFP8:
+      return 1;
+    case nvinfer1::DataType::kINT4:
+    default:
+      ml_loge ("Element size is not implemented for data-type");
+  }
+  ml_loge ("Unable to determine tensorrt10 data type size");
+  throw std::runtime_error ("Unable to determine tensorrt10 data type size");
+}
+
+/**
+ * @brief Get the corresponding nnstreamer tensor_type based on the TensorRT data type.
+ * @param[in] tensorrt10_data_type : The TensorRT data type.
+ * @return The nnstreamer tensor_type.
+ */
+tensor_type
+tensorrt10_subplugin::getNnStreamerDataType (nvinfer1::DataType tensorrt10_data_type) const
+{
+  switch (tensorrt10_data_type) {
+    case nvinfer1::DataType::kINT64:
+      return _NNS_INT64;
+    case nvinfer1::DataType::kINT32:
+      return _NNS_INT32;
+    case nvinfer1::DataType::kFLOAT:
+      return _NNS_FLOAT32;
+    case nvinfer1::DataType::kBF16:
+    case nvinfer1::DataType::kHALF:
+      return _NNS_FLOAT16;
+    case nvinfer1::DataType::kBOOL:
+    case nvinfer1::DataType::kUINT8:
+      return _NNS_UINT8;
+    case nvinfer1::DataType::kINT8:
+      return _NNS_INT8;
+    case nvinfer1::DataType::kFP8:
+    case nvinfer1::DataType::kINT4:
+    default:
+      ml_loge ("Element size is not implemented for data type.");
+  }
+  ml_loge ("Unable to get the nnstreamer data type");
+  throw std::runtime_error ("Unable to get the nnstreamer data type");
+}
+
+tensorrt10_subplugin *tensorrt10_subplugin::registeredRepresentation = nullptr;
+
+/**
+ * @brief Register the tensorrt10_subplugin object.
+ */
+void
+tensorrt10_subplugin::init_filter_tensorrt10 (void)
+{
+  registeredRepresentation
+      = tensor_filter_subplugin::register_subplugin<tensorrt10_subplugin> ();
+}
+
+/**
+ * @brief Unregister the tensorrt10_subplugin object.
+ */
+void
+tensorrt10_subplugin::fini_filter_tensorrt10 (void)
+{
+  assert (registeredRepresentation != nullptr);
+  tensor_filter_subplugin::unregister_subplugin (registeredRepresentation);
+}
+
+/** @brief Initialize this object for tensor_filter subplugin runtime register */
+void
+_init_filter_tensorrt10 (void)
+{
+  tensorrt10_subplugin::init_filter_tensorrt10 ();
+}
+
+/** @brief Destruct the subplugin */
+void
+_fini_filter_tensorrt10 (void)
+{
+  tensorrt10_subplugin::fini_filter_tensorrt10 ();
+}
+
+} /* namespace tensorfilter_tensorrt10 */
+} /* namespace nnstreamer */
index a4555e1..527224e 100644 (file)
@@ -265,42 +265,68 @@ if not get_option('snpe-support').disabled()
   endif
 endif
 
-# tensorrt
-nvinfer_dep = dependency('', required: false)
-nvparsers_dep = dependency('', required: false)
+# cuda
+cuda_version_str = ''
+cuda_major = 0
+cuda_minor = 0
 cuda_dep = dependency('', required: false)
 cudart_dep = dependency('', required: false)
-if not get_option('tensorrt-support').disabled()
-  # check available cuda versions (11.0 and 10.2 are recommended)
-  cuda_vers = [
-    '11.0',
-    '10.2',
-    '10.1',
-    '10.0',
-    '9.2',
-    '9.1',
-    '9.0'
-  ]
+if not get_option('tensorrt-support').disabled() or not get_option('tensorrt10-support').disabled()
+  nvcc = find_program('nvcc', required: false)
+  if nvcc.found()
+    cuda_version_str = run_command(find_program('tools/cuda-version.sh'), check: true).stdout().strip()
+    message('$cuda_version_str: @0@'.format(cuda_version_str))
+
+    cuda_versions = cuda_version_str.split('.')
+    cuda_major = cuda_versions[0].to_int()
+    cuda_minor = cuda_versions[1].to_int()
+
+    cuda_dep = dependency('cuda-' + cuda_version_str, required: false)
+    cudart_dep = dependency('cudart-' + cuda_version_str, required: false)
+  endif
+endif
 
-  foreach ver : cuda_vers
-    cuda_dep = dependency('cuda-' + ver, required: false)
-    cudart_dep = dependency('cudart-' + ver, required: false)
-    if cuda_dep.found() and cudart_dep.found()
-      if ver != '11.0' and ver != '10.2'
-        message('Warning: the recommended cuda version is at least 10.2')
-      endif
-      break
-    endif
-  endforeach
+# tensorrt
+tensorrt_version_str = ''
+tensorrt_major = 0
+tensorrt_minor = 0
+nvinfer_dep = dependency('', required: false)
+nvuffparsers_dep = dependency('', required: false)
+nvonnxparser_dep = dependency('', required: false)
+if cuda_major > 0 and (not get_option('tensorrt-support').disabled() or not get_option('tensorrt10-support').disabled())
+
+  tensorrt_version_str = run_command(find_program('tools/tensorrt-version.sh'), check: true).stdout().strip()
+  message('$tensorrt_version_str: @0@'.format(tensorrt_version_str))
+
+  tensorrt_versions = tensorrt_version_str.split('.')
+  tensorrt_major = tensorrt_versions[0].to_int()
+  tensorrt_minor = tensorrt_versions[1].to_int()
 
   nvinfer_lib = cxx.find_library('nvinfer', required: false)
-  if nvinfer_lib.found() and cxx.check_header('NvInfer.h')
+  if nvinfer_lib.found() and cxx.has_header('NvInfer.h')
     nvinfer_dep = declare_dependency(dependencies: nvinfer_lib)
   endif
 
-  nvparsers_lib = cxx.find_library('nvparsers', required: false)
-  if nvparsers_lib.found() and cxx.check_header('NvUffParser.h')
-    nvparsers_dep = declare_dependency(dependencies: nvparsers_lib)
+  if tensorrt_major < 10 and not get_option('tensorrt-support').disabled()
+    nvuffparsers_lib = cxx.find_library('nvuffparsers', required: false)
+    if nvuffparsers_lib.found() and cxx.has_header('NvUffParser.h')
+      nvuffparsers_dep = declare_dependency(dependencies: nvuffparsers_lib)
+    endif
+  elif tensorrt_major >= 10 and not get_option('tensorrt10-support').disabled()
+    nvonnxparser_lib = cxx.find_library('nvonnxparser', required: false)
+    if nvonnxparser_lib.found() and cxx.has_header('NvOnnxParser.h')
+      nvonnxparser_dep = declare_dependency(dependencies: nvonnxparser_lib)
+    endif
+  endif
+endif
+
+# cuda/tensorrt version checks
+if cuda_version_str != '' and tensorrt_version_str != ''
+  if cuda_major < 11 and tensorrt_major >= 10
+    message('Warning: TensorRT 10 requires at least cuda 11.0')
+  endif
+  if cuda_major < 10 or (cuda_major == 10 and cuda_minor < 2)
+    message('Warning: the recommended cuda version is at least 10.2')
   endif
 endif
 
@@ -460,9 +486,13 @@ features = {
     'project_args': { 'ENABLE_PROTOBUF': 1 }
   },
   'tensorrt-support': {
-    'extra_deps': [ nvinfer_dep, nvparsers_dep, cuda_dep, cudart_dep ],
+    'extra_deps': [ nvinfer_dep, nvuffparsers_dep, cuda_dep, cudart_dep ],
     'project_args': { 'ENABLE_TENSORRT': 1 }
   },
+  'tensorrt10-support': {
+    'extra_deps': [ nvinfer_dep, nvonnxparser_dep, cuda_dep, cudart_dep ],
+    'project_args': { 'ENABLE_TENSORRT10': 1 }
+  },
   'grpc-support': {
     'extra_deps': [ grpc_dep, gpr_dep, grpcpp_dep ],
     'project_args': { 'ENABLE_GRPC': 1 }
index f26c3f6..d9da5cd 100644 (file)
@@ -18,6 +18,7 @@ option('snpe-support', type: 'feature', value: 'auto')
 option('protobuf-support', type: 'feature', value: 'auto')
 option('flatbuf-support', type: 'feature', value: 'auto')
 option('tensorrt-support', type: 'feature', value: 'auto')
+option('tensorrt10-support', type: 'feature', value: 'auto')
 option('grpc-support', type: 'feature', value: 'auto')
 option('lua-support', type: 'feature', value: 'auto')
 option('mqtt-support', type: 'feature', value: 'auto')
diff --git a/tests/nnstreamer_filter_tensorrt10/runTest.sh b/tests/nnstreamer_filter_tensorrt10/runTest.sh
new file mode 100755 (executable)
index 0000000..7ee2051
--- /dev/null
@@ -0,0 +1,92 @@
+#!/usr/bin/env bash
+##
+## SPDX-License-Identifier: LGPL-2.1-only
+##
+## @file runTest.sh
+## @author Suyeon Kim <suyeon5.kim@samsung.com>
+## @date Oct 30 2023
+## @brief SSAT Test Cases for NNStreamer
+##
+
+if [[ "$SSATAPILOADED" != "1" ]]; then
+    SILENT=0
+    INDEPENDENT=1
+    search="ssat-api.sh"
+    source $search
+    printf "${Blue}Independent Mode${NC}"
+fi
+
+# This is compatible with SSAT (https://github.com/myungjoo/SSAT)
+testInit $1
+
+# NNStreamer and plugins path for test
+PATH_TO_PLUGIN="../../build"
+
+if [[ -d $PATH_TO_PLUGIN ]]; then
+    ini_path="${PATH_TO_PLUGIN}/ext/nnstreamer/tensor_filter"
+    if [[ -d ${ini_path} ]]; then
+        check=$(ls ${ini_path} | grep tensorrt.so)
+        if [[ ! $check ]]; then
+            echo "Cannot find TensorRT shared lib"
+            report
+            exit
+        fi
+    else
+        echo "Cannot find ${ini_path}"
+    fi
+else
+    ini_file="/etc/nnstreamer.ini"
+    if [[ -f ${ini_file} ]]; then
+        path=$(grep "^filters" ${ini_file})
+        key=${path%=*}
+        value=${path##*=}
+
+        if [[ $key != "filters" ]]; then
+            echo "String Error"
+            report
+            exit
+        fi
+
+        if [[ -d ${value} ]]; then
+            check=$(ls ${value} | grep tensorrt.so)
+            if [[ ! $check ]]; then
+                echo "Cannot find TensorRT lib"
+                report
+                exit
+            fi
+        else
+            echo "Cannot find ${value}"
+            report
+            exit
+        fi
+    else
+        echo "Cannot identify nnstreamer.ini"
+        report
+        exit
+    fi
+fi
+
+PATH_TO_MODEL="../test_models/models/yolov5nu_224.onnx"
+PATH_TO_LABEL="../test_models/labels/coco.txt"
+PATH_TO_IMAGE="../test_models/data/orange.png"
+
+gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} \
+    filesrc location=${PATH_TO_IMAGE} ! \
+    pngdec ! \
+    videoscale ! \
+    imagefreeze ! \
+    videoconvert ! \
+    video/x-raw,width=224,height=224,format=RGB,framerate=0/1 ! \
+    tensor_converter ! \
+    tensor_transform mode=transpose option=1:2:0:3 ! \
+    tensor_transform mode=arithmetic option=typecast:float32,div:255 ! \
+    tensor_filter framework=tensorrt10 model=${PATH_TO_MODEL} ! \
+    tensor_transform mode=transpose option=1:0:2:3 ! \
+    tensor_decoder mode=bounding_boxes option1=yolov8 option2=${PATH_TO_LABEL} option3=1 option4=224:224 option5=224:224 ! \
+    multifilesink location=yolov5nu_result_%1d.log" \
+    1 0 0 $PERFORMANCE
+
+# Cleanup
+rm yolov5nu_result_*.log*
+
+report
diff --git a/tests/test_models/labels/coco.txt b/tests/test_models/labels/coco.txt
new file mode 100644 (file)
index 0000000..ec82f0f
--- /dev/null
@@ -0,0 +1,80 @@
+person
+bicycle
+car
+motorbike
+aeroplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+backpack
+umbrella
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+wine glass
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+sofa
+potted plant
+bed
+dining table
+toilet
+tvmonitor
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush
diff --git a/tools/cuda-version.sh b/tools/cuda-version.sh
new file mode 100755 (executable)
index 0000000..2a61665
--- /dev/null
@@ -0,0 +1,2 @@
+#!/bin/bash
+nvcc --version | grep release | awk '{print $5}' | tr ',' ' '
diff --git a/tools/tensorrt-version.sh b/tools/tensorrt-version.sh
new file mode 100755 (executable)
index 0000000..e1fff2b
--- /dev/null
@@ -0,0 +1,8 @@
+#!/bin/bash
+NV_INFER_VERSION_H=$(find /usr/include -iname "NvInferVersion.h")
+if [ -z ${NV_INFER_VERSION_H} ]; then
+  exit
+fi
+NV_TENSORRT_MAJOR=$(cat ${NV_INFER_VERSION_H} | grep NV_TENSORRT_MAJOR | awk '{ print $3 }')
+NV_TENSORRT_MINOR=$(cat ${NV_INFER_VERSION_H} | grep NV_TENSORRT_MINOR | awk '{ print $3 }')
+echo "${NV_TENSORRT_MAJOR}.${NV_TENSORRT_MINOR}"