[tensor_trainer] Create tensor_trainer initial element
authorhyunil park <hyunil46.park@samsung.com>
Fri, 28 Oct 2022 00:09:07 +0000 (09:09 +0900)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Thu, 2 Feb 2023 08:30:19 +0000 (17:30 +0900)
- Add basic function (event handler, finalize, start, stop, change state, property)
- Add property (framework, input, output, inputtype and outputtype)
- Add caps negotiation (transform_caps, fixate_caps and set_caps)
- Add checking pipeline's input tensor info and tensor's set property values
- Add finding framework and sub-plugin
- Add calling invoke function in transfrom function

Signed-off-by: hyunil park <hyunil46.park@samsung.com>
gst/nnstreamer/elements/gsttensor_trainer.c [new file with mode: 0644]
gst/nnstreamer/elements/gsttensor_trainer.h [new file with mode: 0644]
gst/nnstreamer/elements/meson.build
gst/nnstreamer/registerer/nnstreamer.c

diff --git a/gst/nnstreamer/elements/gsttensor_trainer.c b/gst/nnstreamer/elements/gsttensor_trainer.c
new file mode 100644 (file)
index 0000000..69223a2
--- /dev/null
@@ -0,0 +1,989 @@
+/* SPDX-License-Identifier: LGPL-2.1-only */
+/**
+ * Copyright (C) 2022 Samsung Electronics Co., Ltd.
+ *
+ * @file       gsttensor_trainer.c
+ * @date       20 October 2022
+ * @brief      GStreamer plugin to train tensor data using NN Frameworks
+ * @see                https://github.com/nnstreamer/nnstreamer
+ * @author     Hyunil Park <hyunil46.park@samsung.com>
+ * @bug                No known bugs except for NYI items
+ *
+ * ## Example launch line
+ * |[
+ * gst-launch-1.0 videotestsrc !
+ *    video/x-raw, format=RGB, width=640, height=480 ! tensor_converter ! 
+ *    tensor_trainer input=3:640:480 inputtype=uint8 output=1:1:1:1 outputtype=uint8 !
+ *    tensor_sink
+ * ]|
+ *
+ */
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+#include <stdlib.h>
+#include <nnstreamer_subplugin.h>
+#include <nnstreamer_util.h>
+#include "gsttensor_trainer.h"
+
+/**
+ * @brief Default caps string for both sink and source pad.
+ */
+#define CAPS_STRING GST_TENSORS_CAP_MAKE ("{ static, flexible }")
+
+/**
+ * @brief The capabilities of the sink pad
+ */
+static GstStaticPadTemplate sinktemplate = GST_STATIC_PAD_TEMPLATE ("sink",
+    GST_PAD_SINK,
+    GST_PAD_ALWAYS,
+    GST_STATIC_CAPS (CAPS_STRING));
+
+/**
+ * @brief The capabilities of the src pad
+ */
+static GstStaticPadTemplate srctemplate = GST_STATIC_PAD_TEMPLATE ("src",
+    GST_PAD_SRC,
+    GST_PAD_ALWAYS,
+    GST_STATIC_CAPS (CAPS_STRING));
+
+GST_DEBUG_CATEGORY_STATIC (gst_tensor_trainer_debug);
+#define GST_CAT_DEFAULT gst_tensor_trainer_debug
+#define gst_tensor_trainer_parent_class parent_class
+G_DEFINE_TYPE (GstTensorTrainer, gst_tensor_trainer, GST_TYPE_BASE_TRANSFORM);
+
+/**
+ * @brief Default framework property value
+ */
+#define DEFAULT_FRAMEWORK "nntrainer"
+
+/**
+ * @brief Default string property value 
+ */
+#define DEFAULT_STR_PROP_VALUE ""
+
+/**
+ * @brief Default string property value 
+ */
+enum
+{
+  PROP_0,
+  PROP_FRAMEWORK,
+  PROP_INPUT_DIM,
+  PROP_OUTPUT_DIM,
+  PROP_INPUT_TYPE,
+  PROP_OUTPUTTYPE
+};
+
+static void gst_tensor_trainer_set_property (GObject * object, guint prop_id,
+    const GValue * value, GParamSpec * pspec);
+static void gst_tensor_trainer_get_property (GObject * object, guint prop_id,
+    GValue * value, GParamSpec * pspec);
+static gboolean gst_tensor_trainer_start (GstBaseTransform * trans);
+static gboolean gst_tensor_trainer_stop (GstBaseTransform * trans);
+static void gst_tensor_trainer_finalize (GObject * object);
+static gboolean gst_tensor_trainer_sink_event (GstBaseTransform * trans,
+    GstEvent * event);
+static gboolean gst_tensor_trainer_src_event (GstBaseTransform * trans,
+    GstEvent * event);
+static GstFlowReturn gst_tensor_trainer_transform (GstBaseTransform * trans,
+    GstBuffer * inbuf, GstBuffer * outbuf);
+static GstStateChangeReturn gst_tensor_trainer_change_state (GstElement *
+    element, GstStateChange transition);
+static GstCaps *gst_tensor_trainer_transform_caps (GstBaseTransform * trans,
+    GstPadDirection direction, GstCaps * caps, GstCaps * filter);
+static GstCaps *gst_tensor_trainer_fixate_caps (GstBaseTransform * trans,
+    GstPadDirection direction, GstCaps * caps, GstCaps * othercaps);
+static gboolean gst_tensor_trainer_set_caps (GstBaseTransform * trans,
+    GstCaps * incaps, GstCaps * outcaps);
+static gboolean gst_tensor_trainer_transform_size (GstBaseTransform * trans,
+    GstPadDirection direction, GstCaps * caps, gsize size, GstCaps * othercaps,
+    gsize * othersize);
+
+static void
+gst_tensor_trainer_set_prop_framework (GstTensorTrainer * trainer, const gchar * fw_name);
+static void gst_tensor_trainer_set_prop_dimension (GstTensorTrainer * trainer,
+    const GValue * value, const gboolean is_input);
+static void gst_tensor_trainer_set_prop_type (GstTensorTrainer * trainer,
+    const GValue * value, const gboolean is_input);
+static const GstTensorFilterFramework
+    *gst_tensor_trainer_find_best_framework (const char *names);
+static void gst_tensor_trainer_find_framework (GstTensorTrainer * trainer,
+    const char *name);
+static void gst_tensor_trainer_create_framework (GstTensorTrainer * trainer);
+static gsize gst_tensor_trainer_get_tensor_size (GstTensorTrainer * trainer,
+    guint index, gboolean is_input);
+
+/**
+ * @brief initialize the tensor_trainer's class
+ */
+static void
+gst_tensor_trainer_class_init (GstTensorTrainerClass * klass)
+{
+  GObjectClass *gobject_class;
+  GstBaseTransformClass *trans_class;
+  GstElementClass *gstelement_class;
+
+  GST_DEBUG_CATEGORY_INIT (GST_CAT_DEFAULT, "tensor_trainer", 0,
+      "Tensor trainer to train neural network model");
+
+  gobject_class = G_OBJECT_CLASS (klass);
+  trans_class = GST_BASE_TRANSFORM_CLASS (klass);
+  gstelement_class = GST_ELEMENT_CLASS (klass);
+
+  gobject_class->set_property =
+      GST_DEBUG_FUNCPTR (gst_tensor_trainer_set_property);
+  gobject_class->get_property =
+      GST_DEBUG_FUNCPTR (gst_tensor_trainer_get_property);
+  gobject_class->finalize = GST_DEBUG_FUNCPTR (gst_tensor_trainer_finalize);
+
+  /* Called when the element's state changes */
+  gstelement_class->change_state =
+      GST_DEBUG_FUNCPTR (gst_tensor_trainer_change_state);
+
+  /* Called when the element starts processing */
+  trans_class->start = GST_DEBUG_FUNCPTR (gst_tensor_trainer_start);
+  /* Called when the element stop processing */
+  trans_class->stop = GST_DEBUG_FUNCPTR (gst_tensor_trainer_stop);
+
+  /* Event handler on sink pad or src pad */
+  trans_class->sink_event = GST_DEBUG_FUNCPTR (gst_tensor_trainer_sink_event);
+  trans_class->src_event = GST_DEBUG_FUNCPTR (gst_tensor_trainer_src_event);
+
+  /* Transforms incoming buffer */
+  trans_class->transform = GST_DEBUG_FUNCPTR (gst_tensor_trainer_transform);
+
+  /* Caps Negotiation */
+  trans_class->transform_caps =
+      GST_DEBUG_FUNCPTR (gst_tensor_trainer_transform_caps);
+  trans_class->fixate_caps = GST_DEBUG_FUNCPTR (gst_tensor_trainer_fixate_caps);
+  trans_class->set_caps = GST_DEBUG_FUNCPTR (gst_tensor_trainer_set_caps);
+
+  /* Allocation initial outbuffer size */
+  trans_class->transform_size =
+      GST_DEBUG_FUNCPTR (gst_tensor_trainer_transform_size);
+
+  /* Install properties for tensor_trainer */
+  g_object_class_install_property (gobject_class, PROP_FRAMEWORK,
+      g_param_spec_string ("framework", "Framework", "Neural network framework",
+          DEFAULT_FRAMEWORK,
+          G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
+          G_PARAM_STATIC_STRINGS));
+
+  g_object_class_install_property (gobject_class, PROP_INPUT_DIM,
+      g_param_spec_string ("input-dim", "Input dimension",
+          "Input tensors dimension from inner array, up to 4 dimensions ?", "",
+          G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
+          G_PARAM_STATIC_STRINGS));
+
+  g_object_class_install_property (gobject_class, PROP_OUTPUT_DIM,
+      g_param_spec_string ("output-dim", "Output dimension",
+          "Output tensors dimension from inner array, up to 4 dimensions ?", "",
+          G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
+          G_PARAM_STATIC_STRINGS));
+
+  g_object_class_install_property (gobject_class, PROP_INPUT_TYPE,
+      g_param_spec_string ("input-type", "Input tensor element type",
+          "Type of each element of the input tensor ?", "",
+          G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
+          G_PARAM_STATIC_STRINGS));
+
+  g_object_class_install_property (gobject_class, PROP_OUTPUT_TYPE,
+      g_param_spec_string ("output-type", "Output tensor element type",
+          "Type of each element of the input tensor ?", "",
+          G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
+          G_PARAM_STATIC_STRINGS));
+
+  gst_element_class_set_details_simple (gstelement_class, "TensorTrainer",
+      "Trainer/Tensor", "Train tensor data using NN Frameworks",
+      "Samsung Electronics Co., Ltd.");
+
+  /* Add pad template */
+  gst_element_class_add_pad_template (gstelement_class,
+      gst_static_pad_template_get (&srctemplate));
+  gst_element_class_add_pad_template (gstelement_class,
+      gst_static_pad_template_get (&sinktemplate));
+}
+
+/**
+ * @brief Initialize tensor_trainer.
+ */
+static void
+gst_tensor_trainer_init (GstTensorTrainer * trainer)
+{
+  GST_DEBUG ("<ENTER>");
+  trainer->fw_name = g_strdup (DEFAULT_FRAMEWORK);
+  trainer->input_dimensions = g_strdup (DEFAULT_STR_PROP_VALUE);
+  trainer->output_dimensions = g_strdup (DEFAULT_STR_PROP_VALUE);
+  trainer->input_type = g_strdup (DEFAULT_STR_PROP_VALUE);
+  trainer->output_type = g_strdup (DEFAULT_STR_PROP_VALUE);
+
+  trainer->fw = NULL;
+  trainer->fw_opened = 0; /* for test */
+  trainer->configured = 0;
+  trainer->input_configured = 0;
+  trainer->output_configured = 0;
+  trainer->inputtype_configured = 0;
+  trainer->outputtype_configured = 0;
+}
+
+/**
+ * @brief Function to finalize instance.
+ */
+static void
+gst_tensor_trainer_finalize (GObject * object)
+{
+  GstTensorTrainer *trainer;
+
+  trainer = GST_TENSOR_TRAINER (object);
+
+  g_free (trainer->fw_name);
+  g_free (trainer->input_dimensions);
+  g_free (trainer->output_dimensions);
+  g_free (trainer->input_type);
+  g_free (trainer->output_type);
+
+  G_OBJECT_CLASS (parent_class)->finalize (object);
+}
+
+/**
+ * @brief Setter for tensor_trainsink properties.
+ */
+static void
+gst_tensor_trainer_set_property (GObject * object, guint prop_id,
+    const GValue * value, GParamSpec * pspec)
+{
+  GstTensorTrainer *trainer;
+
+  trainer = GST_TENSOR_TRAINER (object);
+
+  switch (prop_id) {
+
+    case PROP_FRAMEWORK:
+      gst_tensor_trainer_set_prop_framework (trainer, g_value_get_string (value));
+      break;
+    case PROP_INPUT_DIM:
+      gst_tensor_trainer_set_prop_dimension (trainer, value, TRUE);
+      break;
+    case PROP_OUTPUT_DIM:
+      gst_tensor_trainer_set_prop_dimension (trainer, value, FALSE);
+      break;
+    case PROP_INPUT_TYPE:
+      gst_tensor_trainer_set_prop_type (trainer, value, TRUE);
+      break;
+    case PROP_OUTPUT_TYPE:
+      gst_tensor_trainer_set_prop_type (trainer, value, FALSE);
+      break;
+    default:
+      G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
+      break;
+  }
+}
+
+/**
+ * @brief Getter tensor_trainsink properties.
+ */
+static void
+gst_tensor_trainer_get_property (GObject * object, guint prop_id,
+    GValue * value, GParamSpec * pspec)
+{
+  GstTensorTrainer *trainer;
+
+  trainer = GST_TENSOR_TRAINER (object);
+
+  switch (prop_id) {
+    case PROP_FRAMEWORK:
+      g_value_set_string (value, trainer->fw_name);
+      break;
+    case PROP_INPUT_DIM:
+      g_value_set_string (value, trainer->input_dimensions);
+      break;
+    case PROP_OUTPUT_DIM:
+      g_value_set_string (value, trainer->output_dimensions);
+      break;
+    case PROP_INPUT_TYPE:
+      g_value_set_string (value, trainer->input_type);
+      break;
+    case PROP_OUTPUT_TYPE:
+      g_value_set_string (value, trainer->output_type);
+      break;
+    default:
+      G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
+      break;
+  }
+}
+
+/**
+ * @brief Change state of tensor_trainsink.
+ */
+static GstStateChangeReturn
+gst_tensor_trainer_change_state (GstElement * element,
+    GstStateChange transition)
+{
+  GstTensorTrainer *trainer = GST_TENSOR_TRAINER (element);
+  GstStateChangeReturn ret = GST_STATE_CHANGE_SUCCESS;
+
+  switch (transition) {
+    case GST_STATE_CHANGE_NULL_TO_READY:
+      GST_INFO_OBJECT (trainer, "NULL_TO_READY");
+      break;
+
+    case GST_STATE_CHANGE_READY_TO_PAUSED:
+      GST_INFO_OBJECT (trainer, "READY_TO_PAUSED");
+      break;
+
+    case GST_STATE_CHANGE_PAUSED_TO_PLAYING:
+      GST_INFO_OBJECT (trainer, "PAUSED_TO_PLAYING");
+      /* start or resume model train */
+      break;
+
+    default:
+      break;
+  }
+
+  ret = GST_ELEMENT_CLASS (parent_class)->change_state (element, transition);
+
+  switch (transition) {
+    case GST_STATE_CHANGE_PLAYING_TO_PAUSED:
+      GST_INFO_OBJECT (trainer, "PLAYING_TO_PAUSED");
+      /* pause model train */
+      break;
+
+    case GST_STATE_CHANGE_PAUSED_TO_READY:
+      GST_INFO_OBJECT (trainer, "PAUSED_TO_READY");
+      /* stop model train */
+      break;
+
+    case GST_STATE_CHANGE_READY_TO_NULL:
+      GST_INFO_OBJECT (trainer, "READY_TO_NULL");
+      /* destroy or reset model ? */
+      break;
+
+    default:
+      break;
+  }
+
+  return ret;
+}
+
+/**
+ * @brief Event handler for sink pad of tensor_trainer
+ */
+static gboolean
+gst_tensor_trainer_sink_event (GstBaseTransform * trans, GstEvent * event)
+{
+  GstTensorTrainer *trainer;
+  trainer = GST_TENSOR_TRAINER_CAST (trans);
+
+  GST_INFO_OBJECT (trainer, "sink pad got event (%s)",
+      gst_event_type_get_name (GST_EVENT_TYPE (event)));
+
+  switch (GST_EVENT_TYPE (event)) {
+    case GST_EVENT_EOS:
+      GST_INFO_OBJECT (trainer, "get GST_EVENT_EOS event..state is %d",
+          GST_STATE (trainer));
+      break;
+    case GST_EVENT_FLUSH_START:
+      GST_INFO_OBJECT (trainer, "get GST_EVENT_FLUSH_START event");
+      break;
+    case GST_EVENT_FLUSH_STOP:
+      GST_INFO_OBJECT (trainer, "get GST_EVENT_FLUSH_STOP event");
+      break;
+    default:
+      break;
+  }
+
+  return GST_BASE_TRANSFORM_CLASS (parent_class)->sink_event (trans, event);
+}
+
+/**
+ * @brief Event handler for src pad of tensor_trainer
+ */
+static gboolean
+gst_tensor_trainer_src_event (GstBaseTransform * trans, GstEvent * event)
+{
+  GstTensorTrainer *trainer;
+  trainer = GST_TENSOR_TRAINER_CAST (trans);
+
+  GST_INFO_OBJECT (trainer, "src pad got event (%s)",
+      gst_event_type_get_name (GST_EVENT_TYPE (event)));
+
+  switch (GST_EVENT_TYPE (event)) {
+    case GST_EVENT_EOS:
+      GST_INFO_OBJECT (trainer, "get GST_EVENT_EOS event..state is %d",
+          GST_STATE (trainer));
+      break;
+    case GST_EVENT_FLUSH_START:
+      GST_INFO_OBJECT (trainer, "get GST_EVENT_FLUSH_START event");
+      break;
+    case GST_EVENT_FLUSH_STOP:
+      GST_INFO_OBJECT (trainer, "get GST_EVENT_FLUSH_STOP event");
+      break;
+    default:
+      break;
+  }
+
+  return GST_BASE_TRANSFORM_CLASS (parent_class)->src_event (trans, event);
+}
+
+
+/**
+ * @brief Called when the element starts processing. optional vmethod of BaseTransform
+ */
+static gboolean
+gst_tensor_trainer_start (GstBaseTransform * trans)
+{
+  GstTensorTrainer *trainer;
+  trainer = GST_TENSOR_TRAINER_CAST (trans);
+
+  if (trainer->fw_name)
+    gst_tensor_trainer_find_framework (trainer, trainer->fw_name);
+  if (trainer->fw) {
+    /* create, compile */
+    gst_tensor_trainer_create_framework (trainer);
+  }
+
+  return TRUE;
+}
+
+/**
+ * @brief Called when the element stops processing. optional vmethod of BaseTransform
+ */
+static gboolean
+gst_tensor_trainer_stop (GstBaseTransform * trans)
+{
+  GstTensorTrainer *trainer;
+  trainer = GST_TENSOR_TRAINER_CAST (trans);
+
+  UNUSED (trainer);
+
+  return TRUE;
+}
+
+/**
+ * @brief Transforms one incoming buffer to one outgoing buffer
+ */
+static GstFlowReturn
+gst_tensor_trainer_transform (GstBaseTransform * trans, GstBuffer * inbuf,
+    GstBuffer * outbuf)
+{
+  GstTensorTrainer *trainer;
+  gint ret = -1;
+  guint mem_blocks, i;
+  gsize header_size, expected;
+  gboolean in_flexible, out_flexible;
+  GstMemory *in_mem[NNS_TENSOR_SIZE_LIMIT] = { 0, };
+  GstMapInfo in_info[NNS_TENSOR_SIZE_LIMIT];
+  GstMemory *out_mem[NNS_TENSOR_SIZE_LIMIT] = { 0, };
+  GstMapInfo out_info[NNS_TENSOR_SIZE_LIMIT];
+  GstTensorMemory in_tensors[NNS_TENSOR_SIZE_LIMIT];
+  GstTensorMemory invoke_tensors[NNS_TENSOR_SIZE_LIMIT];
+  GstTensorMemory out_tensors[NNS_TENSOR_SIZE_LIMIT];
+  GstTensorMetaInfo in_meta[NNS_TENSOR_SIZE_LIMIT];
+  GstTensorMetaInfo out_meta[NNS_TENSOR_SIZE_LIMIT];
+
+  trainer = GST_TENSOR_TRAINER_CAST (trans);
+
+  /* Get all input tensors from inbuf */
+  mem_blocks = gst_buffer_n_memory (inbuf);
+  for (i = 0; i < mem_blocks; i++) {
+    in_mem[i] = gst_buffer_peek_memory (inbuf, i);
+    if (!gst_memory_map (in_mem[i], &in_info[i], GST_MAP_READ)) {
+      GST_ERROR_OBJECT (trainer, "Could not map in_mem[%d] GstMemory", i);
+      goto error;
+    }
+    in_flexible =
+        gst_tensor_pad_caps_is_flexible (GST_BASE_TRANSFORM_SINK_PAD (trans));
+    /* Get header size */
+    header_size = 0;
+    if (in_flexible) {
+      gst_tensor_meta_info_parse_header (&in_meta[i], &in_info[i].data);
+      header_size = gst_tensor_meta_info_get_header_size (&in_meta[i]);
+      GST_INFO ("flexible header size:%zd", header_size);
+    } else {
+      GST_INFO ("not flexible header size:%zd", header_size);
+    }
+
+    in_tensors[i].data = in_info[i].data + header_size;
+    in_tensors[i].size = in_info[i].size - header_size;
+  }
+
+  /* Prepare tensor to invoke */
+  /* Check number of input tensors */
+  if (mem_blocks != trainer->input_meta.num_tensors) {
+    GST_ERROR_OBJECT (trainer, "Invalid memory blocks(%d),"
+        "number of input tensors may be (%d)", mem_blocks,
+        trainer->input_meta.num_tensors);
+    goto error;
+  }
+
+  /* Check size of input tensors */
+  for (i = 0; i < trainer->input_meta.num_tensors; i++) {
+    expected = gst_tensor_trainer_get_tensor_size (trainer, i, TRUE);
+    if (expected != in_tensors[i].size) {
+      GST_ERROR_OBJECT (trainer, "Invalid tensor size (%u'th memory chunk: %zd)"
+          ", expected size (%zd)", i, in_tensors[i].size, expected);
+      goto error;
+    }
+    /* Copy to data pointer */
+    invoke_tensors[i] = in_tensors[i];
+  }
+
+  /* Prepare output tensor */
+  for (i = 0; i < trainer->output_meta.num_tensors; i++) {
+    out_tensors[i].data = NULL;
+    out_tensors[i].size =
+        gst_tensor_trainer_get_tensor_size (trainer, i, FALSE);
+
+    /* Get header size */
+    header_size = 0;
+    out_flexible =
+        gst_tensor_pad_caps_is_flexible (GST_BASE_TRANSFORM_SRC_PAD (trans));
+    if (out_flexible) {
+      gst_tensor_info_convert_to_meta (&trainer->output_meta.info[i],
+          &out_meta[i]);
+      header_size = gst_tensor_meta_info_get_header_size (&out_meta[i]);
+      GST_INFO ("flexible header size:%zd", header_size);
+    } else {
+      GST_INFO ("not flexible header size:%zd", header_size);
+    }
+
+    out_mem[i] =
+        gst_allocator_alloc (NULL, out_tensors[i].size + header_size, NULL);
+    if (!out_mem[i]) {
+      GST_ERROR_OBJECT (trainer, "Failed to allocate memory");
+      goto error;
+    }
+
+    if (!gst_memory_map (out_mem[i], &out_info[i], GST_MAP_WRITE)) {
+      GST_ERROR_OBJECT (trainer, "Could not map in_mem[%d] GstMemory", i);
+      goto error;
+    }
+
+    out_tensors[i].data = out_info[i].data + header_size;
+
+    /* Append header */
+    if (out_flexible) {
+      if (!gst_tensor_meta_info_update_header (&out_meta[i], out_info[i].data)) {
+        GST_ERROR_OBJECT (trainer, "Failed to update header ");
+        goto error;
+      }
+    }
+  }
+
+  /* Call the trainer-subplugin callback, invoke */
+  ret =
+      trainer->fw->invoke_NN (&trainer->prop, &trainer->privateData,
+      invoke_tensors, out_tensors);
+
+  /* Free map info and handle */
+  for (i = 0; i < mem_blocks; i++)
+    gst_memory_unmap (in_mem[i], &in_info[i]);
+
+  for (i = 0; i < trainer->output_meta.num_tensors; i++) {
+    gst_memory_unmap (out_mem[i], &out_info[i]);
+    //if (ret != 0) {
+    //  gst_allocator_free (out_mem[i]->allocator, out_mem[i]);
+    //}
+  }
+
+  if (ret < 0) {
+    GST_ERROR_OBJECT (trainer, "Invoke error");
+    // return GST_FLOW_ERROR;
+  } else if (ret > 0) {
+    /* drop this buffer */
+    // return GST_BASE_TRANSFORM_FLOW_DROPPED;
+  }
+
+  GST_INFO ("out buffer size : %zd", gst_buffer_get_size (outbuf));
+  /*Update result */
+  for (i = 0; i < trainer->output_meta.num_tensors; i++) {
+    /* append the memory block to outbuf */
+    gst_buffer_append_memory (outbuf, out_mem[i]);
+  }
+  GST_INFO ("after out buffer size : %zd", gst_buffer_get_size (outbuf));
+
+  return GST_FLOW_OK;
+
+error:
+  mem_blocks = gst_buffer_n_memory (inbuf);
+  for (i = 0; i < mem_blocks; i++) {
+    if (in_mem[i])
+      gst_memory_unmap (in_mem[i], &in_info[i]);
+  }
+
+  for (i = 0; i < trainer->output_meta.num_tensors; i++) {
+    if (out_mem[i]) {
+      gst_memory_unmap (out_mem[i], &out_info[i]);
+      gst_allocator_free (out_mem[i]->allocator, out_mem[i]);
+    }
+  }
+
+  return GST_FLOW_ERROR;
+}
+
+/**
+ * @brief configure tensor-srcpad cap from "proposed" cap.
+ */
+static GstCaps *
+gst_tensor_trainer_transform_caps (GstBaseTransform * trans,
+    GstPadDirection direction, GstCaps * caps, GstCaps * filter)
+{
+  GstTensorTrainer *trainer;
+  GstCaps *result;
+  GstPad *pad;
+  GstStructure *structure;
+  gboolean configured = FALSE;
+
+  GstTensorsConfig in_config, out_config;
+
+  trainer = GST_TENSOR_TRAINER_CAST (trans);
+
+  GST_DEBUG_OBJECT (trans,
+      "[In direction %s] Transforming caps %" GST_PTR_FORMAT "",
+      (direction == GST_PAD_SINK) ? "sink" : "src", caps);
+
+  if (direction == GST_PAD_SRC)
+    pad = GST_BASE_TRANSFORM_SINK_PAD (trans);
+  else
+    pad = GST_BASE_TRANSFORM_SRC_PAD (trans);
+
+  gst_tensors_config_init (&in_config);
+  gst_tensors_config_init (&out_config);
+  structure = gst_caps_get_structure (caps, 0);
+  gst_tensors_config_from_structure (&in_config, structure);
+
+  /* Set framerate from input config */
+  out_config.rate_n = in_config.rate_n;
+  out_config.rate_d = in_config.rate_d;
+
+  /* Need to set property (input, inputtype, output, outputtype)
+     for trainer->input_meta and trainer->output_meta */
+  if (direction == GST_PAD_SRC) {
+    if (trainer->input_configured) {
+      gst_tensors_info_copy (&out_config.info, &trainer->input_meta);
+      configured = TRUE;
+    }
+  } else {
+    if (trainer->output_configured) {
+      configured = TRUE;
+      gst_tensors_info_copy (&out_config.info, &trainer->output_meta);
+    }
+  }
+
+  if (configured)
+    /* Output info may be configured */
+    result = gst_tensor_pad_possible_caps_from_config (pad, &out_config);
+  else
+    result = gst_caps_from_string (CAPS_STRING);
+
+  GST_DEBUG_OBJECT (trans, "caps intersect without filter %" GST_PTR_FORMAT, result);
+
+  if (filter) {
+    GstCaps *intersection;
+    intersection =
+        gst_caps_intersect_full (result, filter, GST_CAPS_INTERSECT_FIRST);
+    gst_caps_unref (result);
+    result = intersection;
+    GST_DEBUG_OBJECT (trans, "result caps %" GST_PTR_FORMAT, result);
+  }
+  gst_tensors_config_free (&in_config);
+  gst_tensors_config_free (&out_config);
+
+  return result;
+}
+
+/**
+ * @brief fixate caps. required vmethod of GstBaseTransform.
+ */
+static GstCaps *
+gst_tensor_trainer_fixate_caps (GstBaseTransform * trans,
+    GstPadDirection direction, GstCaps * caps, GstCaps * othercaps)
+{
+  UNUSED (direction);
+  UNUSED (caps);
+
+  othercaps = gst_caps_fixate (othercaps);
+  GST_DEBUG_OBJECT (trans, "fixated to %" GST_PTR_FORMAT, othercaps);
+
+  return othercaps;
+}
+
+/**
+ * @brief set caps. required vmethod of GstBaseTransform.
+ */
+static gboolean
+gst_tensor_trainer_set_caps (GstBaseTransform * trans, GstCaps * incaps,
+    GstCaps * outcaps)
+{
+  GstTensorTrainer *trainer;
+  GstStructure *structure;
+  GstTensorsConfig in_config;
+
+  trainer = GST_TENSOR_TRAINER_CAST (trans);
+
+  GST_DEBUG_OBJECT (trainer, "[incaps] : %" GST_PTR_FORMAT, incaps);
+  GST_DEBUG_OBJECT (trainer, "[outcaps] : %" GST_PTR_FORMAT, outcaps);
+
+  gst_tensors_config_init (&in_config);
+  structure = gst_caps_get_structure (incaps, 0);
+  gst_tensors_config_from_structure (&in_config, structure);
+
+  if (!gst_tensors_info_is_equal (&in_config.info, &trainer->input_meta)) {
+    GST_ERROR_OBJECT (trainer,
+        "The input tensors info is different between incaps and set property value. "
+        "Please check pipeline's input tensor info and tensor_trainer's set property values"
+        "(input, inputtype, output and outputtype)");
+    return FALSE;
+  }
+
+  gst_tensors_config_free (&in_config);
+
+  return TRUE;
+}
+
+/**
+ * @brief Handle "PROP_FRAMEWORK" for set-property
+ */
+static void
+gst_tensor_trainer_set_prop_framework (GstTensorTrainer * trainer, const gchar * fw_name)
+{
+  g_free (trainer->fw_name);
+  trainer->fw_name = g_strdup (fw_name);
+  GST_INFO_OBJECT (trainer, "framework: %s", trainer->fw_name);
+
+  /** @todo Check valid framework */
+}
+
+/**
+ * @brief Handle "PROP_INPUT_DIM" and "PROP_OUTPUT_DIM" for set-property
+ */
+static void
+gst_tensor_trainer_set_prop_dimension (GstTensorTrainer * trainer,
+    const GValue * value, const gboolean is_input)
+{
+  GstTensorsInfo *info;
+  unsigned int *rank;
+  guint num_dims;
+  gchar **str_dims;
+  guint i;
+
+  if ((is_input && trainer->input_configured) || (!is_input
+          && trainer->output_configured)) {
+    GST_ERROR_OBJECT (trainer,
+        "Cannot change %s dimension" "the element/pipeline is configured.",
+        (is_input) ? "input" : "output");
+    return;
+  }
+
+  if (is_input) {
+    info = &trainer->input_meta;
+    rank = trainer->input_ranks;
+    trainer->input_configured = TRUE;
+    g_free (trainer->input_dimensions);
+    trainer->input_dimensions = g_value_dup_string (value);
+    GST_INFO_OBJECT (trainer, "input: %s", trainer->input_dimensions);
+  } else {
+    info = &trainer->output_meta;
+    rank = trainer->output_ranks;
+    trainer->output_configured = TRUE;
+    g_free (trainer->output_dimensions);
+    trainer->output_dimensions = g_value_dup_string (value);
+    GST_INFO_OBJECT (trainer, "output: %s", trainer->output_dimensions);
+  }
+
+  str_dims = g_strsplit_set (g_value_get_string (value), ",.", -1);
+  num_dims = g_strv_length (str_dims);
+
+  if (num_dims > NNS_TENSOR_SIZE_LIMIT) {
+    GST_WARNING_OBJECT (trainer, "Invalid param, dimensions(%d) max(%d)",
+        num_dims, NNS_TENSOR_SIZE_LIMIT);
+    num_dims = NNS_TENSOR_SIZE_LIMIT;
+  }
+
+  for (i = 0; i < num_dims; i++)
+    rank[i] = gst_tensor_parse_dimension (str_dims[i], info->info[i].dimension);
+
+  info->num_tensors = num_dims;
+
+  g_strfreev (str_dims);
+}
+
+/**
+ * @brief Handle "PROP_INPUT_TYPE" and "PROP_OUTPUT_TYPE" for set-property
+ */
+static void
+gst_tensor_trainer_set_prop_type (GstTensorTrainer * trainer,
+    const GValue * value, const gboolean is_input)
+{
+  GstTensorsInfo *info;
+  guint num_types;
+
+  if ((is_input && trainer->inputtype_configured) || (!is_input
+          && trainer->outputtype_configured)) {
+    GST_ERROR_OBJECT (trainer,
+        "Cannot change %stype" "the element/pipeline is configured.",
+        (is_input) ? "input" : "output");
+    return;
+  }
+
+  if (is_input) {
+    info = &trainer->input_meta;
+    trainer->inputtype_configured = TRUE;
+    g_free (trainer->input_type);
+    trainer->input_type = g_value_dup_string (value);
+    GST_INFO_OBJECT (trainer, "inputtype : %s", trainer->input_type);
+  } else {
+    info = &trainer->output_meta;
+    trainer->outputtype_configured = TRUE;
+    g_free (trainer->output_type);
+    trainer->output_type = g_value_dup_string (value);
+    GST_INFO_OBJECT (trainer, "outputtype : %s", trainer->output_type);
+  }
+
+  num_types =
+      gst_tensors_info_parse_types_string (info, g_value_get_string (value));
+
+  info->num_tensors = num_types;
+}
+
+/**
+ * @brief Find Trainer sub-plugin with the name.
+ */
+static void
+gst_tensor_trainer_find_framework (GstTensorTrainer * trainer, const char *name)
+{
+  const GstTensorFilterFramework *fw = NULL;
+  gchar *str;
+  g_return_if_fail (name != NULL);
+
+  GST_INFO ("find framework: %s", name);
+
+  /* Need to add trainer type to subpluginType */
+  fw = get_subplugin (NNS_SUBPLUGIN_FILTER, name);
+
+  if (fw == NULL) {
+    /*Get sub-plugin priority from ini file and find sub-plugin */
+    str = nnsconf_get_custom_value_string (name, "subplugin_priority");
+    fw = gst_tensor_trainer_find_best_framework (str);
+    g_free (str);
+  }
+
+  if (fw == NULL) {
+    /* Check the filter-alias from ini file */
+    str = nnsconf_get_custom_value_string ("filter-aliases", name);
+    fw = gst_tensor_trainer_find_best_framework (str);
+    g_free (str);
+  }
+
+  if (fw) {
+    GST_INFO_OBJECT (trainer, "find framework %s:%p", trainer->fw_name, fw);
+    trainer->fw = fw;
+  } else {
+    GST_ERROR_OBJECT (trainer, "Can not find framework(%s)", trainer->fw_name);
+  }
+}
+
+/**
+ * @brief Create NN framework.
+ */
+static void
+gst_tensor_trainer_create_framework (GstTensorTrainer * trainer)
+{
+  if (!trainer->fw || trainer->fw_opened) {
+    GST_ERROR_OBJECT (trainer, "fw is not opened(%d) or fw is null(%p)",
+        trainer->fw_opened, trainer->fw);
+    return;
+  }
+
+  /* For test */
+  if (!trainer->fw->open) {     /* fw->create, create model with configuration file */
+    GST_ERROR_OBJECT (trainer, "Could not find fw->create");
+    return;
+  }
+  /* Test code, need to create with load ini file */
+  if (trainer->fw->open (&trainer->prop, &trainer->privateData) >= 0)
+    trainer->fw_created = TRUE;
+}
+
+/**
+ * @brief Find sub-plugin trainer given the name list
+ */
+static const GstTensorFilterFramework *
+gst_tensor_trainer_find_best_framework (const char *names)
+{
+  const GstTensorFilterFramework *fw = NULL; /* need to change to GstTensorTrainerFramework */
+  gchar **subplugins;
+  guint i, len;
+
+  if (names == NULL || names[0] == '\0')
+    return NULL;
+
+  subplugins = g_strsplit_set (names, " ,;", -1);
+
+  len = g_strv_length (subplugins);
+
+  for (i = 0; i < len; i++) {
+    if (strlen (g_strstrip (subplugins[i])) == 0)
+      continue;
+
+    fw = get_subplugin (NNS_SUBPLUGIN_FILTER, subplugins[i]); /* need to add trainer type to subpluginType */
+    if (fw) {
+      GST_INFO ("i = %d found %s", i, subplugins[i]);
+      break;
+    }
+  }
+  g_strfreev (subplugins);
+
+  return fw;
+}
+
+/**
+ * @brief Calculate tensor buffer size
+ */
+gsize
+gst_tensor_trainer_get_tensor_size (GstTensorTrainer * trainer, guint index,
+    gboolean is_input)
+{
+  GstTensorsInfo *info;
+
+  if (is_input)
+    info = &trainer->input_meta;
+  else
+    info = &trainer->output_meta;
+
+  /* Internal Logic Error: out of bound */
+  if (index >= info->num_tensors) {
+    GST_ERROR_OBJECT (trainer, "has inconsistent data");
+    return 0;
+  }
+
+  return gst_tensor_info_get_size (&info->info[index]);
+}
+
+/**
+ * @brief Allocation initial outbuffer size
+ */
+static gboolean
+gst_tensor_trainer_transform_size (GstBaseTransform * trans,
+    GstPadDirection direction, GstCaps * caps, gsize size,
+    GstCaps * othercaps, gsize * othersize)
+{
+  GstTensorTrainer *trainer;
+
+  UNUSED (direction);
+  UNUSED (caps);
+  UNUSED (size);
+  UNUSED (othercaps);
+  trainer = GST_TENSOR_TRAINER_CAST (trans);
+
+  GST_DEBUG_OBJECT (trainer, "trainer->configured: %d", trainer->configured);
+
+  /** Internal Logic Error. Cannot proceed without configured pipeline */
+  //g_assert (trainer->configured);
+
+  *othersize = 0;
+
+  return TRUE;
+}
diff --git a/gst/nnstreamer/elements/gsttensor_trainer.h b/gst/nnstreamer/elements/gsttensor_trainer.h
new file mode 100644 (file)
index 0000000..258ebc9
--- /dev/null
@@ -0,0 +1,92 @@
+/* SPDX-License-Identifier: LGPL-2.1-only */
+/**
+ * Copyright (C) 2022 Samsung Electronics Co., Ltd.
+ *
+ * @file       gsttensor_trainer.h
+ * @date       20 October 2022
+ * @brief      GStreamer plugin to train tensor data using NN Frameworks
+ * @see                https://github.com/nnstreamer/nnstreamer
+ * @author     Hyunil Park <hyunil46.park@samsung.com>
+ * @bug                No known bugs except for NYI items
+ */
+
+#ifndef __GST_TENSOR_TRAINER_H__
+#define __GST_TENSOR_TRAINER_H__
+
+
+#include <gst/gst.h>
+#include <gst/base/gstbasetransform.h>
+#include <tensor_typedef.h>
+#include <tensor_common.h>
+
+#include <nnstreamer_plugin_api_util.h>
+#include <nnstreamer_plugin_api_filter.h>
+
+G_BEGIN_DECLS
+#define GST_TYPE_TENSOR_TRAINER \
+  (gst_tensor_trainer_get_type())
+#define GST_TENSOR_TRAINER(obj) \
+  (G_TYPE_CHECK_INSTANCE_CAST((obj),GST_TYPE_TENSOR_TRAINER,GstTensorTrainer))
+#define GST_TENSOR_TRAINER_CLASS(klass) \
+  (G_TYPE_CHECK_CLASS_CAST((klass),GST_TYPE_TENSOR_TRAINER,GstTensorTrainerClass))
+#define GST_IS_TENSOR_TRAINER(obj) \
+  (G_TYPE_CHECK_INSTANCE_TYPE((obj),GST_TYPE_TENSOR_TRAINER))
+#define GST_IS_TENSOR_TRAINER_CLASS(klass) \
+  (G_TYPE_CHECK_CLASS_TYPE((klass),GST_TYPE_TENSOR_TRAINER))
+#define GST_TENSOR_TRAINER_CAST(obj)  ((GstTensorTrainer *)(obj))
+typedef struct _GstTensorTrainer GstTensorTrainer;
+typedef struct _GstTensorTrainerClass GstTensorTrainerClass;
+
+
+/**
+ * @brief GstTensorTrainer data structure
+ */
+struct _GstTensorTrainer
+{
+  GstBaseTransform element;
+
+  gchar *fw_name;
+  gchar *input_dimensions;
+  gchar *output_dimensions;
+  gchar *input_type;
+  gchar *output_type;
+  GstTensorsInfo input_meta;
+  GstTensorsInfo output_meta;
+
+  gboolean configured;
+
+  int input_configured;
+  int output_configured;
+  int inputtype_configured;
+  int outputtype_configured;
+  unsigned int input_ranks[NNS_TENSOR_SIZE_LIMIT];
+  unsigned int output_ranks[NNS_TENSOR_SIZE_LIMIT];
+
+  /* draft */
+  int fw_opened;
+  int fw_compiled;
+  int fw_fitted;
+  int fw_created;
+  int fw_stop;
+  int fw_paused;
+
+  void *privateData; /**< NNFW plugin's private data is stored here */
+  const GstTensorFilterFramework *fw;   /* for test, need to make */
+  GstTensorFilterProperties prop; /**< NNFW plugin's properties */
+};
+
+/**
+ * @brief GstTensorTrainerClass data structure.
+ */
+struct _GstTensorTrainerClass
+{
+  GstBaseTransformClass parent_class;
+};
+
+/**
+ * @brief Get Type function required for gst elements
+ */
+GType gst_tensor_trainer_get_type (void);
+
+G_END_DECLS
+#endif /* __GST_TENSOR_TRAINER_H__ */
index 5ae4f43..e5c065b 100644 (file)
@@ -18,6 +18,7 @@ nnstreamer_sources += files(
   'gsttensor_sparseutil.c',
   'gsttensor_split.c',
   'gsttensor_transform.c'
+  'gsttensor_trainer.c'
 )
 
 # gsttensorsrc
index 1cb9155..138519b 100644 (file)
@@ -64,6 +64,7 @@
 #include <elements/gsttensor_sparseenc.h>
 #include <elements/gsttensor_split.h>
 #include <elements/gsttensor_transform.h>
+#include <elements/gsttensor_trainer.h>
 
 #ifdef _ENABLE_SRC_IIO
 #include <elements/gsttensor_srciio.h>
@@ -108,6 +109,7 @@ gst_nnstreamer_init (GstPlugin * plugin)
   NNSTREAMER_INIT (plugin, transform, TRANSFORM);
   NNSTREAMER_INIT (plugin, if, IF);
   NNSTREAMER_INIT (plugin, rate, RATE);
+  NNSTREAMER_INIT (plugin, trainer, TRAINER);
 #if defined(ENABLE_NNSTREAMER_EDGE)
   NNSTREAMER_INIT (plugin, query_serversrc, QUERY_SERVERSRC);
   NNSTREAMER_INIT (plugin, query_serversink, QUERY_SERVERSINK);