[Transform] refactor arith/typecast
authorJaeyun <jy1210.jung@samsung.com>
Wed, 14 Nov 2018 09:31:18 +0000 (18:31 +0900)
committerMyungJoo Ham <myungjoo.ham@gmail.com>
Mon, 19 Nov 2018 07:58:27 +0000 (07:58 +0000)
refactor arith/typecast to handle arithmetic in sequence

1. add operands and handle tensor element with parsed option
2. add common functions for the operators
3. remove multi-op (add-mul and mul-add) and handle operator in sequence

Signed-off-by: Jaeyun Jung <jy1210.jung@samsung.com>
gst/tensor_transform/tensor_transform.c
gst/tensor_transform/tensor_transform.h
tests/transform_arithmetic/runTest.sh

index 22fe903..4150d41 100644 (file)
@@ -109,12 +109,12 @@ static const gchar *gst_tensor_transform_stand_string[] = {
   [STAND_END] = "error"
 };
 
-static const gchar *gst_tensor_transform_arithmetic_string[] = {
-  [ARITH_ADD] = "add",
-  [ARITH_MUL] = "mul",
-  [ARITH_ADD_MUL] = "add-mul",
-  [ARITH_MUL_ADD] = "mul-add",
-  [ARITH_END] = "error"
+static const gchar *gst_tensor_transform_operator_string[] = {
+  [GTT_OP_TYPECAST] = "typecast",
+  [GTT_OP_ADD] = "add",
+  [GTT_OP_MUL] = "mul",
+  [GTT_OP_DIV] = "div",
+  [GTT_OP_UNKNOWN] = "unknown"
 };
 
 /**
@@ -226,24 +226,25 @@ gst_tensor_transform_init (GstTensorTransform * filter)
   filter->mode = GTT_END;
   filter->option = NULL;
   filter->loaded = FALSE;
+  filter->operators = NULL;
 
   gst_tensor_config_init (&filter->in_config);
   gst_tensor_config_init (&filter->out_config);
 }
 
 /**
- * @brief Get the corresponding mode from the string value
- * @param[in] str The string value for the mode
- * @return corresponding mode for the string. ARITH_END for errors
+ * @brief Get the corresponding operator from the string value
+ * @param[in] str The string value for the operator
+ * @return corresponding operator for the string (GTT_OP_UNKNOWN for errors)
  */
-static tensor_transform_arith_mode
-gst_tensor_transform_get_arith_mode (const gchar * str)
+static tensor_transform_operator
+gst_tensor_transform_get_operator (const gchar * str)
 {
   int index;
 
-  index = find_key_strv (gst_tensor_transform_arithmetic_string, str);
+  index = find_key_strv (gst_tensor_transform_operator_string, str);
 
-  return (index < 0) ? ARITH_END : index;
+  return (index < 0) ? GTT_OP_UNKNOWN : index;
 }
 
 /**
@@ -277,6 +278,315 @@ gst_tensor_transform_get_mode (const gchar * str)
 }
 
 /**
+ * @brief Macro to set operand
+ */
+#define set_operand_value(v,d,vtype) do { \
+    (v)->data._##vtype = *((vtype *) d); \
+  } while (0)
+
+/**
+ * @brief Set tensor element value with given type
+ * @param filter "this" pointer
+ * @param value struct for operand of arith mode
+ * @param type tensor type
+ * @param data pointer of tensor element value
+ * @return TRUE if no error
+ */
+static gboolean
+gst_tensor_transform_set_value (GstTensorTransform * filter,
+    tensor_transform_operand_s * value, tensor_type type, gpointer data)
+{
+  g_return_val_if_fail (value != NULL, FALSE);
+  g_return_val_if_fail (data != NULL, FALSE);
+
+  /* init tensor value */
+  memset (value, 0, sizeof (tensor_transform_operand_s));
+  value->type = _NNS_END;
+
+  switch (type) {
+    case _NNS_INT32:
+      set_operand_value (value, data, int32_t);
+      break;
+    case _NNS_UINT32:
+      set_operand_value (value, data, uint32_t);
+      break;
+    case _NNS_INT16:
+      set_operand_value (value, data, int16_t);
+      break;
+    case _NNS_UINT16:
+      set_operand_value (value, data, uint16_t);
+      break;
+    case _NNS_INT8:
+      set_operand_value (value, data, int8_t);
+      break;
+    case _NNS_UINT8:
+      set_operand_value (value, data, uint8_t);
+      break;
+    case _NNS_FLOAT64:
+      set_operand_value (value, data, double);
+      break;
+    case _NNS_FLOAT32:
+      set_operand_value (value, data, float);
+      break;
+    case _NNS_INT64:
+      set_operand_value (value, data, int64_t);
+      break;
+    case _NNS_UINT64:
+      set_operand_value (value, data, uint64_t);
+      break;
+    default:
+      GST_ERROR_OBJECT (filter, "Unknown tensor type %d", type);
+      return FALSE;
+  }
+
+  value->type = type;
+  return TRUE;
+}
+
+/**
+ * @brief Macro to get operand
+ */
+#define get_operand_value(v,d,vtype) do { \
+    *((vtype *) d) = (v)->data._##vtype; \
+  } while (0)
+
+/**
+ * @brief Get tensor element value with given type
+ * @param filter "this" pointer
+ * @param value struct for operand of arith mode
+ * @param data pointer of tensor element value
+ * @return TRUE if no error
+ */
+static gboolean
+gst_tensor_transform_get_value (GstTensorTransform * filter,
+    tensor_transform_operand_s * value, gpointer data)
+{
+  g_return_val_if_fail (value != NULL, FALSE);
+  g_return_val_if_fail (data != NULL, FALSE);
+
+  switch (value->type) {
+    case _NNS_INT32:
+      get_operand_value (value, data, int32_t);
+      break;
+    case _NNS_UINT32:
+      get_operand_value (value, data, uint32_t);
+      break;
+    case _NNS_INT16:
+      get_operand_value (value, data, int16_t);
+      break;
+    case _NNS_UINT16:
+      get_operand_value (value, data, uint16_t);
+      break;
+    case _NNS_INT8:
+      get_operand_value (value, data, int8_t);
+      break;
+    case _NNS_UINT8:
+      get_operand_value (value, data, uint8_t);
+      break;
+    case _NNS_FLOAT64:
+      get_operand_value (value, data, double);
+      break;
+    case _NNS_FLOAT32:
+      get_operand_value (value, data, float);
+      break;
+    case _NNS_INT64:
+      get_operand_value (value, data, int64_t);
+      break;
+    case _NNS_UINT64:
+      get_operand_value (value, data, uint64_t);
+      break;
+    default:
+      GST_ERROR_OBJECT (filter, "Unknown tensor type %d", value->type);
+      return FALSE;
+  }
+
+  return TRUE;
+}
+
+/**
+ * @brief Macro for operator
+ */
+#define handle_operator(d,v,oper,vtype) do { \
+    switch (oper) { \
+      case GTT_OP_ADD: \
+        (d)->data._##vtype += (v)->data._##vtype; \
+        break; \
+      case GTT_OP_MUL: \
+        (d)->data._##vtype *= (v)->data._##vtype; \
+        break; \
+      case GTT_OP_DIV: \
+        if ((v)->data._##vtype == 0) { \
+          GST_ERROR_OBJECT (filter, "Invalid state, denominator is 0."); \
+          return FALSE; \
+        } \
+        (d)->data._##vtype /= (v)->data._##vtype; \
+        break; \
+      default: \
+        GST_ERROR_OBJECT (filter, "Unknown operator %d", oper); \
+        return FALSE; \
+    } \
+  } while (0)
+
+/**
+ * @brief Handle operators for tensor value
+ * @param filter "this" pointer
+ * @param desc struct for tensor value
+ * @param val struct for tensor value
+ * @param op operator for given tensor value
+ * @return TRUE if no error
+ */
+static gboolean
+gst_tensor_transform_do_operator (GstTensorTransform * filter,
+    tensor_transform_operand_s * desc, const tensor_transform_operand_s * val,
+    tensor_transform_operator op)
+{
+  g_return_val_if_fail (desc != NULL, FALSE);
+  g_return_val_if_fail (val != NULL, FALSE);
+  g_return_val_if_fail (desc->type == val->type, FALSE);
+
+  switch (desc->type) {
+    case _NNS_INT32:
+      handle_operator (desc, val, op, int32_t);
+      break;
+    case _NNS_UINT32:
+      handle_operator (desc, val, op, uint32_t);
+      break;
+    case _NNS_INT16:
+      handle_operator (desc, val, op, int16_t);
+      break;
+    case _NNS_UINT16:
+      handle_operator (desc, val, op, uint16_t);
+      break;
+    case _NNS_INT8:
+      handle_operator (desc, val, op, int8_t);
+      break;
+    case _NNS_UINT8:
+      handle_operator (desc, val, op, uint8_t);
+      break;
+    case _NNS_FLOAT64:
+      handle_operator (desc, val, op, double);
+      break;
+    case _NNS_FLOAT32:
+      handle_operator (desc, val, op, float);
+      break;
+    case _NNS_INT64:
+      handle_operator (desc, val, op, int64_t);
+      break;
+    case _NNS_UINT64:
+      handle_operator (desc, val, op, uint64_t);
+      break;
+    default:
+      GST_ERROR_OBJECT (filter, "Unknown tensor type %d", desc->type);
+      return FALSE;
+  }
+
+  return TRUE;
+}
+
+/**
+ * @brief Macro for typecast
+ */
+#define typecast_value_to(v,itype,otype) do { \
+    itype in_val = (v)->data._##itype; \
+    otype out_val = (otype) in_val; \
+    (v)->data._##otype = out_val; \
+  } while (0)
+
+#define typecast_value(v,otype) do { \
+    switch ((v)->type) { \
+      case _NNS_INT32: typecast_value_to (v, int32_t, otype); break; \
+      case _NNS_UINT32: typecast_value_to (v, uint32_t, otype); break; \
+      case _NNS_INT16: typecast_value_to (v, int16_t, otype); break; \
+      case _NNS_UINT16:  typecast_value_to (v, uint16_t, otype); break; \
+      case _NNS_INT8: typecast_value_to (v, int8_t, otype); break; \
+      case _NNS_UINT8: typecast_value_to (v, uint8_t, otype); break; \
+      case _NNS_FLOAT64: typecast_value_to (v, double, otype); break; \
+      case _NNS_FLOAT32: typecast_value_to (v, float, otype); break; \
+      case _NNS_INT64: typecast_value_to (v, int64_t, otype); break; \
+      case _NNS_UINT64: typecast_value_to (v, uint64_t, otype); break; \
+      default: g_assert (0); break; \
+    } \
+  } while (0)
+
+/**
+ * @brief Typecast tensor element value
+ * @param filter "this" pointer
+ * @param value struct for operand of arith mode
+ * @param type tensor type to be transformed
+ * @return TRUE if no error
+ */
+static gboolean
+gst_tensor_transform_typecast_value (GstTensorTransform * filter,
+    tensor_transform_operand_s * value, tensor_type type)
+{
+  gboolean is_float;
+
+  g_return_val_if_fail (value != NULL, FALSE);
+  g_return_val_if_fail (type != _NNS_END, FALSE);
+
+  /* do nothing when transform to same type */
+  if (value->type != type) {
+    is_float = (type == _NNS_FLOAT32 || type == _NNS_FLOAT64);
+
+    switch (type) {
+      case _NNS_INT32:
+        typecast_value (value, int32_t);
+        break;
+      case _NNS_UINT32:
+        if (is_float) {
+          /* int32 -> uint32 */
+          typecast_value (value, int32_t);
+        }
+        typecast_value (value, uint32_t);
+        break;
+      case _NNS_INT16:
+        typecast_value (value, int16_t);
+        break;
+      case _NNS_UINT16:
+        if (is_float) {
+          /* int16 -> uint16 */
+          typecast_value (value, int16_t);
+        }
+        typecast_value (value, uint16_t);
+        break;
+      case _NNS_INT8:
+        typecast_value (value, int8_t);
+        break;
+      case _NNS_UINT8:
+        if (is_float) {
+          /* int8 -> uint8 */
+          typecast_value (value, int8_t);
+        }
+        typecast_value (value, uint8_t);
+        break;
+      case _NNS_FLOAT64:
+        typecast_value (value, double);
+        break;
+      case _NNS_FLOAT32:
+        typecast_value (value, float);
+        break;
+      case _NNS_INT64:
+        typecast_value (value, int64_t);
+        break;
+      case _NNS_UINT64:
+        if (is_float) {
+          /* int64 -> uint64 */
+          typecast_value (value, int64_t);
+        }
+        typecast_value (value, uint64_t);
+        break;
+      default:
+        GST_ERROR_OBJECT (filter, "Unknown tensor type %d", type);
+        return FALSE;
+    }
+
+    value->type = type;
+  }
+
+  return TRUE;
+}
+
+/**
  * @brief Setup internal data (data_* in GstTensorTransform)
  * @param[in/out] filter "this" pointer. mode & option MUST BE set already.
  */
@@ -317,46 +627,89 @@ gst_tensor_transform_set_option_data (GstTensorTransform * filter)
     }
     case GTT_ARITHMETIC:
     {
-      gchar **strv = g_strsplit (filter->option, ":", 2);
-
-      if (strv[0] != NULL) {
-        filter->data_arithmetic.mode =
-            gst_tensor_transform_get_arith_mode (strv[0]);
-        g_assert (filter->data_arithmetic.mode != ARITH_END);
-      }
-
-      if (strv[1] != NULL) {
-        gchar **operands = g_strsplit (strv[1], ":", 2);
-        gchar *not_consumed;
-        int i;
-
-        for (i = 0; i < ARITH_OPRND_NUM_LIMIT; ++i) {
-          filter->data_arithmetic.value[i].type = ARITH_OPRND_TYPE_END;
-          if ((operands[i] != NULL) && (strlen (operands[i]) != 0)) {
-            if (strchr (operands[i], '.') || strchr (operands[i], 'e') ||
-                strchr (operands[i], 'E')) {
-              filter->data_arithmetic.value[i].type = ARITH_OPRND_TYPE_DOUBLE;
-              filter->data_arithmetic.value[i].value_double =
-                  g_ascii_strtod (operands[i], &not_consumed);
-            } else {
-              filter->data_arithmetic.value[i].type = ARITH_OPRND_TYPE_INT64;
-              filter->data_arithmetic.value[i].value_int64 =
-                  g_ascii_strtoll (operands[i], &not_consumed, 10);
-            }
+      gchar **str_operators;
+      gchar **str_op;
+      tensor_transform_operator_s *op_s;
+      guint i, num_operators, num_op;
+
+      filter->data_arithmetic.out_type = _NNS_END;
+
+      str_operators = g_strsplit (filter->option, ",", -1);
+      num_operators = g_strv_length (str_operators);
+
+      for (i = 0; i < num_operators; ++i) {
+        str_op = g_strsplit (str_operators[i], ":", -1);
+        num_op = g_strv_length (str_op);
+
+        if (str_op[0]) {
+          op_s = g_new0 (tensor_transform_operator_s, 1);
+          g_assert (op_s);
+
+          op_s->op = gst_tensor_transform_get_operator (str_op[0]);
+
+          switch (op_s->op) {
+            case GTT_OP_TYPECAST:
+              if (num_op > 1 && str_op[1]) {
+                op_s->value.type = get_tensor_type (str_op[1]);
+
+                if (op_s->value.type == _NNS_END) {
+                  GST_WARNING_OBJECT (filter, "Unknown tensor type %s",
+                      str_op[1]);
+                  op_s->op = GTT_OP_UNKNOWN;
+                } else {
+                  filter->data_arithmetic.out_type = op_s->value.type;
+                }
+              } else {
+                GST_WARNING_OBJECT (filter, "Invalid option for typecast %s",
+                    str_operators[i]);
+                op_s->op = GTT_OP_UNKNOWN;
+              }
+              break;
+            case GTT_OP_ADD:
+            case GTT_OP_MUL:
+            case GTT_OP_DIV:
+              if (num_op > 1 && str_op[1]) {
+                /* get operand */
+                if (strchr (str_op[1], '.') || strchr (str_op[1], 'e') ||
+                    strchr (str_op[1], 'E')) {
+                  double val;
+
+                  val = g_ascii_strtod (str_op[1], NULL);
+                  gst_tensor_transform_set_value (filter, &op_s->value,
+                      _NNS_FLOAT64, &val);
+                } else {
+                  int64_t val;
+
+                  val = g_ascii_strtoll (str_op[1], NULL, 10);
+                  gst_tensor_transform_set_value (filter, &op_s->value,
+                      _NNS_INT64, &val);
+                }
+              } else {
+                GST_WARNING_OBJECT (filter, "Invalid option for arithmetic %s",
+                    str_operators[i]);
+                op_s->op = GTT_OP_UNKNOWN;
+              }
+              break;
+            default:
+              GST_WARNING_OBJECT (filter, "Unknown operator %s", str_op[0]);
+              break;
+          }
 
-            if (strlen (not_consumed)) {
-              g_printerr ("%s is not a valid integer or floating point value\n",
-                  operands[i]);
-              g_assert (0);
-            }
+          /* append operator */
+          if (op_s->op != GTT_OP_UNKNOWN) {
+            filter->operators = g_slist_append (filter->operators, op_s);
+          } else {
+            g_free (op_s);
           }
+        } else {
+          GST_WARNING_OBJECT (filter, "Invalid option %s", str_operators[i]);
         }
 
-        g_strfreev (operands);
+        g_strfreev (str_op);
       }
 
-      filter->loaded = TRUE;
-      g_strfreev (strv);
+      filter->loaded = (filter->operators != NULL);
+      g_strfreev (str_operators);
       break;
     }
     case GTT_TRANSPOSE:
@@ -463,6 +816,11 @@ gst_tensor_transform_finalize (GObject * object)
     filter->option = NULL;
   }
 
+  if (filter->operators) {
+    g_slist_free_full (filter->operators, g_free);
+    filter->operators = NULL;
+  }
+
   G_OBJECT_CLASS (parent_class)->finalize (object);
 }
 
@@ -545,77 +903,6 @@ gst_tensor_transform_dimchg (GstTensorTransform * filter,
 }
 
 /**
- * Macro to run loop for various data types with simple cast
- */
-#define castloop(itype,otype,num) do { \
-    otype *ptr = (otype *) outptr; \
-    itype *iptr = (itype *) inptr; \
-    size_t i; \
-    for (i = 0; i < num; i++) { \
-      *(ptr + i) = (otype) *(iptr + i); \
-    } \
-  } while (0)
-
-/**
- * Macro to run loop for various data types with simple cast
- * While castloop directly casts itype to otype, this macro indirectly casts
- * itype to otype using mtype as an intermediate
- */
-#define castloop_via_intermediate(itype, mtype, otype, num) do { \
-    otype *ptr = (otype *) outptr; \
-    itype *iptr = (itype *) inptr; \
-    size_t i; \
-    for (i = 0; i < num; i++) { \
-      mtype m = (mtype) *(iptr + i);\
-      *(ptr + i) = (otype) m; \
-    } \
-  } while (0)
-
-/**
- * Macro to run loop for various data types with a converter function
- */
-#define convloop(itype,otype,num,convfunc) do { \
-    otype *ptr = (otype *) outptr; \
-    itype *iptr = (itype *) inptr; \
-    size_t i; \
-    for (i = 0; i < num; i++) { \
-      *(ptr + i) = convfunc(iptr + i); \
-    } \
-  } while (0)
-
-/**
- * Macro to unburden switch cases with castloop/convloop (per itype)
- * This is for cases otype is numeral.
- */
-#define numotype_castloop_per_itype(otype,num) do { \
-    switch (in_tensor_type) { \
-    case _NNS_INT8: castloop(int8_t, otype, num); break; \
-    case _NNS_INT16: castloop(int16_t, otype, num); break; \
-    case _NNS_INT32: castloop(int32_t, otype, num); break; \
-    case _NNS_UINT8: castloop(uint8_t, otype, num); break; \
-    case _NNS_UINT16: castloop(uint16_t, otype, num); break; \
-    case _NNS_UINT32: castloop(uint32_t, otype, num); break; \
-    case _NNS_FLOAT32: castloop(float, otype, num); break; \
-    case _NNS_FLOAT64: castloop(double, otype, num); break; \
-    case _NNS_INT64: castloop(int64_t, otype, num); break; \
-    case _NNS_UINT64: castloop(uint64_t, otype, num); break; \
-    default: g_assert(0); return GST_FLOW_ERROR; \
-    } \
-  } while (0)
-
-#define numotype_castloop_via_intermediate_for_float_itype(mtype, otype, num) do { \
-    switch (in_tensor_type) { \
-     case _NNS_FLOAT32:\
-      castloop_via_intermediate(float, mtype, otype, num); \
-      break; \
-    case _NNS_FLOAT64: \
-      castloop_via_intermediate(double, mtype, otype, num); \
-      break; \
-    default: g_assert(0); \
-    } \
-  } while (0)
-
-/**
  * @brief subrouting for tensor-tranform, "typecast" case.
  * @param[in/out] filter "this" pointer
  * @param[in] inptr input tensor
@@ -626,162 +913,32 @@ static GstFlowReturn
 gst_tensor_transform_typecast (GstTensorTransform * filter,
     const uint8_t * inptr, uint8_t * outptr)
 {
-  uint32_t num = get_tensor_element_count (filter->in_config.info.dimension);
+  size_t num = get_tensor_element_count (filter->in_config.info.dimension);
   tensor_type in_tensor_type = filter->in_config.info.type;
+  tensor_type out_tensor_type = filter->out_config.info.type;
 
-  switch (filter->data_typecast.to) {
-    case _NNS_INT8:
-      numotype_castloop_per_itype (int8_t, num);
-      break;
-    case _NNS_INT16:
-      numotype_castloop_per_itype (int16_t, num);
-      break;
-    case _NNS_INT32:
-      numotype_castloop_per_itype (int32_t, num);
-      break;
-    case _NNS_UINT8:
-      if ((in_tensor_type == _NNS_FLOAT32) || (in_tensor_type == _NNS_FLOAT64)) {
-        numotype_castloop_via_intermediate_for_float_itype (int8_t, uint8_t,
-            num);
-      } else {
-        numotype_castloop_per_itype (uint8_t, num);
-      }
-      break;
-    case _NNS_UINT16:
-      if ((in_tensor_type == _NNS_FLOAT32) || (in_tensor_type == _NNS_FLOAT64)) {
-        numotype_castloop_via_intermediate_for_float_itype (int16_t, uint16_t,
-            num);
-      } else {
-        numotype_castloop_per_itype (uint16_t, num);
-      }
-      break;
-    case _NNS_UINT32:
-      if ((in_tensor_type == _NNS_FLOAT32) || (in_tensor_type == _NNS_FLOAT64)) {
-        numotype_castloop_via_intermediate_for_float_itype (int32_t, uint32_t,
-            num);
-      } else {
-        numotype_castloop_per_itype (uint32_t, num);
-      }
-      break;
-    case _NNS_FLOAT32:
-      numotype_castloop_per_itype (float, num);
-      break;
-    case _NNS_FLOAT64:
-      numotype_castloop_per_itype (double, num);
-      break;
-    case _NNS_INT64:
-      numotype_castloop_per_itype (int64_t, num);
-      break;
-    case _NNS_UINT64:
-      if ((in_tensor_type == _NNS_FLOAT32) || (in_tensor_type == _NNS_FLOAT64)) {
-        numotype_castloop_via_intermediate_for_float_itype (int64_t, uint64_t,
-            num);
-      } else {
-        numotype_castloop_per_itype (uint64_t, num);
-      }
-      break;
-    default:
-      g_assert (0);
-      return GST_FLOW_ERROR;
-  }
-
-  return GST_FLOW_OK;
-}
+  tensor_transform_operand_s value;
+  size_t i, data_idx;
 
-/**
- * Macro to run loop for various data types with simple arithmetic which has single operand
- */
-#define arith(itype,num,op,a) do { \
-    size_t i; \
-    itype *in = (itype *) inptr; \
-    itype *out = (itype *) outptr; \
-    for (i=0;i<num;i++){ \
-      *(out+i) = (*(in+i) op a); \
-    } \
-  }while(0);
+  for (i = 0; i < num; ++i) {
+    /* init value with input tensor type */
+    data_idx = tensor_element_size[in_tensor_type] * i;
+    gst_tensor_transform_set_value (filter, &value, in_tensor_type,
+        (gpointer) (inptr + data_idx));
 
-/**
- * Macro to run loop for various data types with simple arithmetic which has dual operands
- */
-#define arith2(itype,num,op1,a,op2,b) do { \
-    size_t i; \
-    itype *in = (itype *) inptr; \
-    itype *out = (itype *) outptr; \
-    for (i=0;i<num;i++){ \
-      *(out+i) = (*(in+i) op1 a) op2 b; \
-    } \
-  }while(0);
-
-/**
- * Macro to handle the case of single operand
- */
-#define arithmode_single_oprnd_case(itype,num,mode,op,value) do { \
-  itype a;\
-  switch (value[0].type) {\
-  case ARITH_OPRND_TYPE_INT64 : a = (itype) value[0].value_int64; break; \
-  case ARITH_OPRND_TYPE_DOUBLE : a = (itype) value[0].value_double; break;\
-  default: \
-  g_printerr ("The operand required by \'%s\' is not properly provided.\n", \
-      gst_tensor_transform_arithmetic_string[filter->data_arithmetic.mode]);\
-  g_assert(0); \
-  }; \
-  arith(itype, num, op, a); break; \
-} while (0);
-
-/**
- * Macro to handle the case of dual operands
- */
-#define arithmode_dual_oprnd_case(itype,num,mode,op1,op2,value) \
-do {\
-  itype a;\
-  itype b; \
-  switch (value[0].type) {\
-  case ARITH_OPRND_TYPE_INT64 : a = (itype) value[0].value_int64; break; \
-  case ARITH_OPRND_TYPE_DOUBLE : a = (itype) value[0].value_double; break;\
-  default: \
-  g_printerr ("The operands required by \'%s\' are not properly provided.\n", \
-      gst_tensor_transform_arithmetic_string[filter->data_arithmetic.mode]);\
-  g_assert(0); \
-  }; \
-  switch (value[1].type) {\
-  case ARITH_OPRND_TYPE_INT64 : b = (itype) value[1].value_int64; break; \
-  case ARITH_OPRND_TYPE_DOUBLE : b = (itype) value[1].value_double; break;\
-  default: \
-  g_printerr ("The operands required by \'%s\' are not properly provided.\n", \
-      gst_tensor_transform_arithmetic_string[filter->data_arithmetic.mode]);\
-  g_assert(0); \
-  }; \
-  arith2(itype, num, op1, a, op2, b); break; \
-} while (0);
+    /* typecast */
+    gst_tensor_transform_typecast_value (filter, &value, out_tensor_type);
 
-/**
- * Macro to run loop for various data types with simple arithmetic
- */
-#define arithloopcase(typecase,itype,num,mode,value) \
-  case typecase: \
-  { \
-    switch (mode) { \
-    case ARITH_ADD: {\
-      arithmode_single_oprnd_case (itype, num, mode, +, value); \
-      break; \
-    }; \
-    case ARITH_MUL: { \
-      arithmode_single_oprnd_case (itype, num, mode, *, value); \
-      break; \
-    };\
-    case ARITH_ADD_MUL: {\
-      arithmode_dual_oprnd_case (itype, num, mode, +, *, value); \
-      break; \
-    }; \
-    case ARITH_MUL_ADD: {\
-      arithmode_dual_oprnd_case (itype, num, mode, *, +, value); \
-      break; \
-    }; \
-    default: g_assert(0); return GST_FLOW_ERROR; \
-    } \
-    break; \
+    /* set output value */
+    g_assert (out_tensor_type == value.type);
+    data_idx = tensor_element_size[out_tensor_type] * i;
+    gst_tensor_transform_get_value (filter, &value,
+        (gpointer) (outptr + data_idx));
   }
 
+  return GST_FLOW_OK;
+}
+
 /**
  * @brief subrouting for tensor-tranform, "arithmetic" case.
  * @param[in/out] filter "this" pointer
@@ -793,25 +950,64 @@ static GstFlowReturn
 gst_tensor_transform_arithmetic (GstTensorTransform * filter,
     const uint8_t * inptr, uint8_t * outptr)
 {
-  uint32_t num = get_tensor_element_count (filter->in_config.info.dimension);
+  size_t num = get_tensor_element_count (filter->in_config.info.dimension);
   tensor_type in_tensor_type = filter->in_config.info.type;
-  tensor_transform_arith_mode mode = filter->data_arithmetic.mode;
-  tensor_transform_arithmetic_operand *value = filter->data_arithmetic.value;
-
-  switch (in_tensor_type) {
-      arithloopcase (_NNS_INT8, int8_t, num, mode, value);
-      arithloopcase (_NNS_INT16, int16_t, num, mode, value);
-      arithloopcase (_NNS_INT32, int32_t, num, mode, value);
-      arithloopcase (_NNS_UINT8, uint8_t, num, mode, value);
-      arithloopcase (_NNS_UINT16, uint16_t, num, mode, value);
-      arithloopcase (_NNS_UINT32, uint32_t, num, mode, value);
-      arithloopcase (_NNS_FLOAT32, float, num, mode, value);
-      arithloopcase (_NNS_FLOAT64, double, num, mode, value);
-      arithloopcase (_NNS_INT64, int64_t, num, mode, value);
-      arithloopcase (_NNS_UINT64, uint64_t, num, mode, value);
-    default:
-      g_assert (0);
-      return GST_FLOW_ERROR;
+  tensor_type out_tensor_type = filter->out_config.info.type;
+
+  GSList *walk;
+  tensor_transform_operator_s *op_s;
+  tensor_transform_operand_s value;
+  size_t i, data_idx;
+
+  for (i = 0; i < num; ++i) {
+    /* init value with input tensor type */
+    data_idx = tensor_element_size[in_tensor_type] * i;
+    gst_tensor_transform_set_value (filter, &value, in_tensor_type,
+        (gpointer) (inptr + data_idx));
+
+    walk = filter->operators;
+    while (walk) {
+      op_s = (tensor_transform_operator_s *) walk->data;
+
+      /**
+       * @todo add more options
+       */
+      switch (op_s->op) {
+        case GTT_OP_TYPECAST:
+          gst_tensor_transform_typecast_value (filter, &value,
+              op_s->value.type);
+          break;
+        case GTT_OP_ADD:
+          gst_tensor_transform_typecast_value (filter, &op_s->value,
+              value.type);
+          gst_tensor_transform_do_operator (filter, &value, &op_s->value,
+              GTT_OP_ADD);
+          break;
+        case GTT_OP_MUL:
+          gst_tensor_transform_typecast_value (filter, &op_s->value,
+              value.type);
+          gst_tensor_transform_do_operator (filter, &value, &op_s->value,
+              GTT_OP_MUL);
+          break;
+        case GTT_OP_DIV:
+          gst_tensor_transform_typecast_value (filter, &op_s->value,
+              value.type);
+          gst_tensor_transform_do_operator (filter, &value, &op_s->value,
+              GTT_OP_DIV);
+          break;
+        default:
+          g_assert (0);
+          return GST_FLOW_ERROR;
+      }
+
+      walk = g_slist_next (walk);
+    }
+
+    /* set output value */
+    g_assert (out_tensor_type == value.type);
+    data_idx = tensor_element_size[out_tensor_type] * i;
+    gst_tensor_transform_get_value (filter, &value,
+        (gpointer) (outptr + data_idx));
   }
 
   return GST_FLOW_OK;
@@ -970,6 +1166,7 @@ gst_tensor_transform_transform (GstBaseTransform * trans,
   uint8_t *inptr, *outptr;
   GstMapInfo inInfo, outInfo;
 
+  g_assert (filter->loaded);
   g_assert (gst_buffer_map (inbuf, &inInfo, GST_MAP_READ));
   g_assert (gst_buffer_map (outbuf, &outInfo, GST_MAP_WRITE));
 
@@ -1041,12 +1238,13 @@ gst_tensor_transform_convert_dimension (GstTensorTransform * filter,
     GstPadDirection direction, const GstTensorInfo * in_info,
     GstTensorInfo * out_info)
 {
+  int i;
+
   switch (filter->mode) {
     case GTT_DIMCHG:
       out_info->type = in_info->type;
 
       if (direction == GST_PAD_SINK) {
-        int i;
         int a = filter->data_dimchg.from;
         int b = filter->data_dimchg.to;
 
@@ -1075,7 +1273,6 @@ gst_tensor_transform_convert_dimension (GstTensorTransform * filter,
           }
         }
       } else {
-        int i;
         int a = filter->data_dimchg.from;
         int b = filter->data_dimchg.to;
 
@@ -1105,10 +1302,9 @@ gst_tensor_transform_convert_dimension (GstTensorTransform * filter,
         }
       }
       break;
+
     case GTT_TYPECAST:
-    {
-        /** For both directions, dimension does not change */
-      int i;
+      /** For both directions, dimension does not change */
       for (i = 0; i < NNS_TENSOR_RANK_LIMIT; i++) {
         out_info->dimension[i] = in_info->dimension[i];
       }
@@ -1120,20 +1316,23 @@ gst_tensor_transform_convert_dimension (GstTensorTransform * filter,
         out_info->type = in_info->type;   /** @todo this may cause problems with Cap-Transform */
       }
       break;
-    }
+
     case GTT_ARITHMETIC:
-    {
-      int i;
       for (i = 0; i < NNS_TENSOR_RANK_LIMIT; i++) {
         out_info->dimension[i] = in_info->dimension[i];
       }
       out_info->type = in_info->type;
+
+      /* check arith mode option has typecast operator */
+      if (direction == GST_PAD_SINK &&
+          filter->data_arithmetic.out_type != _NNS_END) {
+        out_info->type = filter->data_arithmetic.out_type;
+      }
       break;
-    }
+
     case GTT_TRANSPOSE:
-    {
       out_info->type = in_info->type;
-      int i;
+
       if (direction == GST_PAD_SINK) {
         for (i = 0; i < NNS_TENSOR_RANK_LIMIT; i++) {
           out_info->dimension[i] =
@@ -1148,16 +1347,14 @@ gst_tensor_transform_convert_dimension (GstTensorTransform * filter,
         }
       }
       break;
-    }
+
     case GTT_STAND:
-    {
-      int i;
       for (i = 0; i < NNS_TENSOR_RANK_LIMIT; i++) {
         out_info->dimension[i] = in_info->dimension[i];
       }
       out_info->type = in_info->type;
       break;
-    }
+
     default:
       return FALSE;
   }
index 6377512..90c9fd9 100644 (file)
@@ -46,8 +46,6 @@ G_BEGIN_DECLS
   (G_TYPE_CHECK_CLASS_TYPE((klass),GST_TYPE_TENSOR_TRANSFORM))
 #define GST_TENSOR_TRANSFORM_CAST(obj)  ((GstTensorTransform *)(obj))
 
-#define ARITH_OPRND_NUM_LIMIT  2
-
 typedef struct _GstTensorTransform GstTensorTransform;
 typedef struct _GstTensorTransformClass GstTensorTransformClass;
 
@@ -64,13 +62,13 @@ typedef enum
 
 typedef enum
 {
-  ARITH_ADD = 0,
-  ARITH_MUL = 1,
-  ARITH_ADD_MUL = 2,            /* Fused add-multiply */
-  ARITH_MUL_ADD = 3,            /* Fused multiply-add */
+  GTT_OP_TYPECAST = 0,
+  GTT_OP_ADD = 1,
+  GTT_OP_MUL = 2,
+  GTT_OP_DIV = 3,
 
-  ARITH_END,
-} tensor_transform_arith_mode;
+  GTT_OP_UNKNOWN
+} tensor_transform_operator;
 
 typedef enum
 {
@@ -78,14 +76,6 @@ typedef enum
   STAND_END,
 } tensor_transform_stand_mode;
 
-typedef enum
-{
-  ARITH_OPRND_TYPE_INT64 = 0,
-  ARITH_OPRND_TYPE_DOUBLE = 1,
-
-  ARITH_OPRND_TYPE_END
-} tensor_transform_arith_oprnd_type;
-
 /**
  * @brief Internal data structure for dimchg mode.
  */
@@ -104,20 +94,26 @@ typedef struct _tensor_transform_typecast {
 /**
  * @brief Internal data structure for operand of arithmetic mode.
  */
-typedef struct _tensor_transform_arithmetic_operand {
-  tensor_transform_arith_oprnd_type type;
-  union {
-    int64_t value_int64;
-    double value_double;
-  };
-} tensor_transform_arithmetic_operand;
+typedef struct
+{
+  tensor_type type;
+  tensor_element data;
+} tensor_transform_operand_s;
+
+/**
+ * @brief Internal data structure for operator of arithmetic mode.
+ */
+typedef struct
+{
+  tensor_transform_operator op;
+  tensor_transform_operand_s value;
+} tensor_transform_operator_s;
 
 /**
  * @brief Internal data structure for arithmetic mode.
  */
 typedef struct _tensor_transform_arithmetic {
-  tensor_transform_arith_mode mode;
-  tensor_transform_arithmetic_operand value[ARITH_OPRND_NUM_LIMIT];
+  tensor_type out_type;
 } tensor_transform_arithmetic;
 
 /**
@@ -153,6 +149,8 @@ struct _GstTensorTransform
   };
   gboolean loaded; /**< TRUE if mode & option are loaded */
 
+  GSList *operators; /**< operators list */
+
   GstTensorConfig in_config; /**< input tensor info */
   GstTensorConfig out_config; /**< output tensor info */
 };
index 5f0c52e..a28bce6 100644 (file)
@@ -45,16 +45,16 @@ python checkResult.py arithmetic testcase02.direct.log testcase02.arithmetic.log
 testResult $? 2 "Golden test comparison" 0 1
 
 # Test for mul with floating-point operand
-gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float32 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=mul:-5.5 ! filesink location=\"testcase03.arithmetic.log\" sync=true t. ! queue ! filesink location=\"testcase03.direct.log\" sync=true" 3 0 0 $PERFORMANCE
+gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float32 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=mul:-5.5 ! filesink location=\"testcase03.arithmetic.1.log\" sync=true t. ! queue ! filesink location=\"testcase03.direct.1.log\" sync=true" 3 0 0 $PERFORMANCE
 
-python checkResult.py arithmetic testcase03.direct.log testcase03.arithmetic.log 4 4 f f mul -5.5 0
+python checkResult.py arithmetic testcase03.direct.1.log testcase03.arithmetic.1.log 4 4 f f mul -5.5 0
 testResult $? 3 "Golden test comparison" 0 1
 
-# Fail Test 3-F: for mul with floating-point operand
-gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float32 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=mul::-5.5 ! filesink location=\"testcase03.arithmetic.fail.log\" sync=true t. ! queue ! filesink location=\"testcase03.direct.fail.log\" sync=true" 3-F 0 1 $PERFORMANCE
+# Test 3-2 for typecast,mul with floating-point operand
+gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tee name=t ! queue ! tensor_transform mode=arithmetic option=typecast:float32,mul:-5.5 ! filesink location=\"testcase03.arithmetic.2.log\" sync=true t. ! queue ! tensor_transform mode=typecast option=float32 ! filesink location=\"testcase03.direct.2.log\" sync=true" 3-2 0 0 $PERFORMANCE
 
-python checkResult.py arithmetic testcase03.direct.fail.log testcase03.arithmetic.fail.log 4 4 f f mul 0 -5.5
-testResult $? 3-F "Golden test comparison" 0 1
+python checkResult.py arithmetic testcase03.direct.2.log testcase03.arithmetic.2.log 4 4 f f mul -5.5 0
+testResult $? 3-2 "Golden test comparison" 0 1
 
 # Test for add with floating-point operand
 gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=add:9.900000e-001 ! filesink location=\"testcase04.arithmetic.log\" sync=true t. ! queue ! filesink location=\"testcase04.direct.log\" sync=true" 4 0 0 $PERFORMANCE
@@ -68,32 +68,14 @@ gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequenc
 python checkResult.py arithmetic testcase04.direct.ok.log testcase04.arithmetic.ok.log 8 8 d d add 9.900000e-001 -80.256
 testResult $? 4-OK "Golden test comparison" 0 1
 
-# Test for add-mul with floating-point operands
-gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=add-mul:-9.3:-11.4823e-002 ! filesink location=\"testcase05.arithmetic.log\" sync=true t. ! queue ! filesink location=\"testcase05.direct.log\" sync=true" 5 0 0 $PERFORMANCE
+# Test for add,mul with floating-point operands
+gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=add:-9.3,mul:-11.4823e-002 ! filesink location=\"testcase05.arithmetic.log\" sync=true t. ! queue ! filesink location=\"testcase05.direct.log\" sync=true" 5 0 0 $PERFORMANCE
 
 testResult $? 5 "Golden test comparison" 0 1
 python checkResult.py arithmetic testcase05.direct.log testcase05.arithmetic.log 8 8 d d add-mul -9.3 -11.4823e-002
 
-# Fail Test 5-F1: add-mul with single floating-point operand
-gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=add-mul:-9.3 ! filesink location=\"testcase05.arithmetic.fail1.log\" sync=true t. ! queue ! filesink location=\"testcase05.direct.fail1.log\" sync=true" 5-F1 0 1 $PERFORMANCE
-
-testResult $? 5-F1 "Golden test comparison" 0 1
-python checkResult.py arithmetic testcase05.direct.fail1.log testcase05.arithmetic.fail1.log 8 8 d d add-mul -9.3 0
-
-# Fail Test 5-F2: add-mul with single floating-point operand
-gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=add-mul:-9.3: ! filesink location=\"testcase05.arithmetic.fail2.log\" sync=true t. ! queue ! filesink location=\"testcase05.direct.fail2.log\" sync=true" 5-F2 0 1 $PERFORMANCE
-
-testResult $? 5-F2 "Golden test comparison" 0 1
-python checkResult.py arithmetic testcase05.direct.fail2.log testcase05.arithmetic.fail2.log 8 8 d d add-mul -9.3 0
-
-# Fail Test 5-F3: add-mul with single floating-point operand
-gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=add-mul::-11.4823e-002 ! filesink location=\"testcase05.arithmetic.fail3.log\" sync=true t. ! queue ! filesink location=\"testcase05.direct.fail3.log\" sync=true" 5-F3 0 1 $PERFORMANCE
-
-testResult $? 5-F3 "Golden test comparison" 0 1
-python checkResult.py arithmetic testcase05.direct.fail3.log testcase05.arithmetic.fail3.log 8 8 d d add-mul 30 -11.4823e-002
-
-# Test for mul-add with floating-point operands
-gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=mul-add:-50.0987e+003:15.3 ! filesink location=\"testcase06.arithmetic.log\" sync=true t. ! queue ! filesink location=\"testcase06.direct.log\" sync=true" 6 0 0 $PERFORMANCE
+# Test for mul,add with floating-point operands
+gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=mul:-50.0987e+003,add:15.3 ! filesink location=\"testcase06.arithmetic.log\" sync=true t. ! queue ! filesink location=\"testcase06.direct.log\" sync=true" 6 0 0 $PERFORMANCE
 
 testResult $? 6 "Golden test comparison" 0 1
 python checkResult.py arithmetic testcase06.direct.log testcase06.arithmetic.log 8 8 d d add-mul -50.0987e+003 15.3