[transform] add per-channel op w/ORC
authorbyunghun <byungs286@gmail.com>
Tue, 11 Apr 2023 05:35:14 +0000 (14:35 +0900)
committerjaeyun-jung <39614140+jaeyun-jung@users.noreply.github.com>
Tue, 11 Apr 2023 08:57:43 +0000 (17:57 +0900)
add if statement to support tensor transform arithmetic orc op w/ORC

Signed-off-by: byunghun <byungs286@gmail.com>
gst/nnstreamer/elements/gsttensor_transform.c

index 023412d..db9c6ed 100644 (file)
@@ -497,6 +497,20 @@ refrain_from_heavy_op_on_float16 (gulong n)
     } \
   } while (0)
 
+#define orc_typesize(size, type) do { \
+    switch (type) { \
+      case _NNS_INT32: size = sizeof(int32_t); break; \
+      case _NNS_UINT32: size = sizeof(uint32_t); break; \
+      case _NNS_INT16: size = sizeof(int16_t); break; \
+      case _NNS_UINT16: size = sizeof(uint16_t); break; \
+      case _NNS_INT8: size = sizeof(int8_t); break; \
+      case _NNS_UINT8: size = sizeof(uint8_t); break; \
+      case _NNS_FLOAT64: size = sizeof(double); break; \
+      case _NNS_FLOAT32: size = sizeof(float); break; \
+      default: GST_ERROR_OBJECT (filter, "Unsupported type %d", type); g_assert (0); break; \
+    } \
+  } while (0)
+
 #define orc_operator_func(i,n,v,opfunc,op) do { \
     switch ((v)->type) { \
       case _NNS_INT32: opfunc (s32) ((gpointer) i, (v)->data._int32_t, n); break; \
@@ -1238,28 +1252,59 @@ gst_tensor_transform_arithmetic (GstTensorTransform * filter,
   num = gst_tensor_get_element_count (in_info->dimension);
 
 #ifdef HAVE_ORC
-  /** per-channel is not supported by orc */
-  if (!filter->data_arithmetic.per_channel_arith
-      && orc_supported (filter, in_info->type, out_info->type)) {
+  if (orc_supported (filter, in_info->type, out_info->type)) {
     walk = filter->operators;
-
     /**
      * Typecast should be called at the first.
      * Do the typecast. If in/out type is same, this will copy the input array to output.
      */
     orc_typecast (inptr, outptr, num, in_info->type, out_info->type);
 
-    while (walk) {
-      op_s = (tensor_transform_operator_s *) walk->data;
+    if (!filter->data_arithmetic.per_channel_arith) {
+      while (walk) {
+        op_s = (tensor_transform_operator_s *) walk->data;
 
-      if (op_s->op != GTT_OP_TYPECAST) {
-        gst_tensor_data_typecast (&op_s->value, out_info->type);
-        orc_operator (outptr, num, &op_s->value, op_s->op);
+        if (op_s->op != GTT_OP_TYPECAST) {
+          gst_tensor_data_typecast (&op_s->value, out_info->type);
+          orc_operator (outptr, num, &op_s->value, op_s->op);
+        }
+
+        walk = g_slist_next (walk);
       }
+    } else {
+      gsize typesize = 0;
+      guint ch_dim = filter->data_arithmetic.ch_dim;
+      gsize ch_offset, ch_size = 1;
+      uint8_t *tmp_outptr = NULL;
 
-      walk = g_slist_next (walk);
-    }
+      for (i = 0; i < ch_dim; ++i) {
+        ch_size *= in_info->dimension[i];
+      }
+      ch_offset = ch_size * in_info->dimension[ch_dim];
+      orc_typesize (typesize, out_info->type);
+
+      while (walk) {
+        op_s = (tensor_transform_operator_s *) walk->data;
+        if (op_s->op == GTT_OP_TYPECAST) {
+          walk = g_slist_next (walk);
+          continue;
+        }
 
+        if (op_s->applying_ch == -1) {
+          gst_tensor_data_typecast (&op_s->value, out_info->type);
+          orc_operator (outptr, num, &op_s->value, op_s->op);
+        } else {
+          for (i = 0; i < num / ch_offset; ++i) {
+            tmp_outptr =
+                outptr + (ch_size * op_s->applying_ch +
+                ch_offset * i) * typesize;
+            gst_tensor_data_typecast (&op_s->value, out_info->type);
+            orc_operator (tmp_outptr, ch_size, &op_s->value, op_s->op);
+          }
+        }
+        walk = g_slist_next (walk);
+      }
+    }
     return GST_FLOW_OK;
   }
 #endif