[Filter] condition to update dimension
authorJaeyun Jung <jy1210.jung@samsung.com>
Mon, 4 Dec 2023 08:22:59 +0000 (17:22 +0900)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Tue, 12 Dec 2023 06:36:53 +0000 (15:36 +0900)
For caps negotiation, compare tensor dimension if caps is tensor stream and already configured.

Signed-off-by: Jaeyun Jung <jy1210.jung@samsung.com>
gst/nnstreamer/include/nnstreamer_plugin_api.h
gst/nnstreamer/nnstreamer_plugin_api_impl.c
gst/nnstreamer/tensor_filter/tensor_filter.c
tests/nnstreamer_plugins/unittest_plugins.cc

index e75af23..bb6e947 100644 (file)
@@ -104,10 +104,10 @@ gst_tensor_meta_info_append_header (GstTensorMetaInfo * meta, GstMemory * mem);
 /**
  * @brief Update caps dimension for negotiation
  * @param caps caps to compare and update
- * @param peer_caps caps to compare
+ * @param filter caps to compare
  */
 extern void
-gst_tensor_caps_update_dimension (GstCaps *caps, GstCaps *peer_caps);
+gst_tensor_caps_update_dimension (GstCaps *caps, GstCaps *filter);
 
 /**
  * @brief  Try intersecting @caps1 and @caps2 for tensor stream
index 3d5a1d6..c566b95 100644 (file)
@@ -1043,42 +1043,45 @@ _is_structure_dimension_same (GstStructure * st1, GstStructure * st2,
 /**
  * @brief Update caps dimensions for negotiation
  * @param caps caps to compare and update
- * @param peer_caps caps to compare
+ * @param filter caps to compare
  */
 void
-gst_tensor_caps_update_dimension (GstCaps * caps, GstCaps * peer_caps)
+gst_tensor_caps_update_dimension (GstCaps * caps, GstCaps * filter)
 {
-  GstStructure *structure;
-  GstStructure *structure_peer;
+  GstStructure *st_caps, *st_filter;
   guint i, j;
 
-  g_return_if_fail (caps != NULL);
-  g_return_if_fail (peer_caps != NULL);
+  g_return_if_fail (GST_IS_CAPS (caps));
+  g_return_if_fail (GST_IS_CAPS (filter));
 
   for (i = 0; i < gst_caps_get_size (caps); i++) {
-    structure = gst_caps_get_structure (caps, i);
+    st_caps = gst_caps_get_structure (caps, i);
+
+    if (!gst_structure_is_tensor_stream (st_caps))
+      continue;
+
+    for (j = 0; j < gst_caps_get_size (filter); j++) {
+      st_filter = gst_caps_get_structure (filter, j);
 
-    for (j = 0; j < gst_caps_get_size (peer_caps); j++) {
-      structure_peer = gst_caps_get_structure (peer_caps, j);
+      if (!gst_structure_is_tensor_stream (st_filter))
+        continue;
 
       /* other/tensor */
-      if (gst_structure_has_field (structure, "dimension")
-          && gst_structure_has_field (structure_peer, "dimension")) {
+      if (gst_structure_has_field (st_caps, "dimension")
+          && gst_structure_has_field (st_filter, "dimension")) {
         /* update dimensions for negotiation */
-        if (_is_structure_dimension_same (structure, structure_peer,
-                "dimension")) {
-          gst_structure_set (structure, "dimension", G_TYPE_STRING,
-              gst_structure_get_string (structure_peer, "dimension"), NULL);
+        if (_is_structure_dimension_same (st_caps, st_filter, "dimension")) {
+          gst_structure_set (st_caps, "dimension", G_TYPE_STRING,
+              gst_structure_get_string (st_filter, "dimension"), NULL);
         }
       }
       /* other/tensors */
-      else if (gst_structure_has_field (structure, "dimensions")
-          && gst_structure_has_field (structure_peer, "dimensions")) {
+      else if (gst_structure_has_field (st_caps, "dimensions")
+          && gst_structure_has_field (st_filter, "dimensions")) {
         /* update dimensions for negotiation */
-        if (_is_structure_dimension_same (structure, structure_peer,
-                "dimensions")) {
-          gst_structure_set (structure, "dimensions", G_TYPE_STRING,
-              gst_structure_get_string (structure_peer, "dimensions"), NULL);
+        if (_is_structure_dimension_same (st_caps, st_filter, "dimensions")) {
+          gst_structure_set (st_caps, "dimensions", G_TYPE_STRING,
+              gst_structure_get_string (st_filter, "dimensions"), NULL);
         }
       }
     }
index 48534b2..d4a0f5d 100644 (file)
@@ -1157,7 +1157,6 @@ gst_tensor_filter_transform_caps (GstBaseTransform * trans,
   GstTensorsConfig in_config, out_config;
   GstPad *pad;
   GstCaps *result;
-  GstCaps *peer_caps;
   GstStructure *structure;
   gboolean configured = FALSE;
 
@@ -1231,19 +1230,22 @@ gst_tensor_filter_transform_caps (GstBaseTransform * trans,
   if (configured) {
     /* output info may be configured */
     result = gst_tensor_pad_possible_caps_from_config (pad, &out_config);
+
+    /* Update dimension for src pad caps. */
+    if (direction == GST_PAD_SINK) {
+      GstCaps *peer = gst_pad_peer_query_caps (pad, NULL);
+
+      if (peer) {
+        if (!gst_caps_is_any (peer))
+          gst_tensor_caps_update_dimension (result, peer);
+        gst_caps_unref (peer);
+      }
+    }
   } else {
     /* we don't know the exact tensor info yet */
     result = gst_caps_from_string (CAPS_STRING);
   }
 
-  /* Update caps dimension for src pad cap */
-  if (direction == GST_PAD_SINK) {
-    if ((peer_caps = gst_pad_peer_query_caps (pad, NULL))) {
-      gst_tensor_caps_update_dimension (result, peer_caps);
-      gst_caps_unref (peer_caps);
-    }
-  }
-
   if (filter && gst_caps_get_size (filter) > 0) {
     GstCaps *intersection;
 
index 5e658b0..ed446a4 100644 (file)
@@ -5143,8 +5143,15 @@ TEST_REQUIRE_TFLITE (testTensorTransform, negotiationFilter)
       "test_models", "models", "mobilenet_v1_1.0_224_quant.tflite", NULL);
   ASSERT_TRUE (g_file_test (test_model, G_FILE_TEST_EXISTS));
 
+  /**
+   * tensor-filter information
+   * input type uint8 dimension 3:224:224:1
+   * output type uint8 dimension 1001:1
+   */
   g_autofree gchar *pipeline = g_strdup_printf (
-      "tensor_transform mode=typecast option=uint8 ! tensor_filter framework=tensorflow-lite model=%s",
+      "tensor_transform mode=typecast option=uint8 ! tensor_filter framework=tensorflow-lite model=%s ! "
+      "other/tensors,num_tensors=1,dimensions=(string)\"1001:1:1:1:1\" ! "
+      "tensor_transform mode=typecast option=int8",
       test_model);
 
   h = gst_harness_new_parse (pipeline);