[Transform/Arithmetic] Support fused add-mul and mul-add operators
authorWook Song <wook16.song@samsung.com>
Thu, 11 Oct 2018 06:53:55 +0000 (15:53 +0900)
committerjaeyun-jung <39614140+jaeyun-jung@users.noreply.github.com>
Fri, 12 Oct 2018 10:55:28 +0000 (19:55 +0900)
This patch adds two operators, fused add-mul and fused mul-add, to the
Arithmetic mode of the transform plugin.

Signed-off-by: Wook Song <wook16.song@samsung.com>
gst/tensor_transform/tensor_transform.c
gst/tensor_transform/tensor_transform.h

index 196d524..9a5f48e 100644 (file)
@@ -197,11 +197,11 @@ static const gchar *gst_tensor_transform_mode_string[] = {
   [GTT_END] = "error",
 };
 
- /*TODO*/
-/* [ARITH_MAD] = "mad", (pixel[] + a) * b should be supported */
 static const gchar *gst_tensor_transform_arithmetic_string[] = {
   [ARITH_ADD] = "add",
   [ARITH_MUL] = "mul",
+  [ARITH_ADD_MUL] = "add-mul",
+  [ARITH_MUL_ADD] = "mul-add",
   [ARITH_END] = "error",
 };
 
@@ -288,7 +288,6 @@ gst_tensor_transform_set_option_data (GstTensor_Transform * filter)
     case GTT_ARITHMETIC:
     {
       gchar **strv = g_strsplit (filter->option, ":", 2);
-      gchar *not_consumed;
 
       if (strv[0] != NULL) {
         filter->data_arithmetic.mode =
@@ -296,22 +295,33 @@ gst_tensor_transform_set_option_data (GstTensor_Transform * filter)
       }
 
       if (strv[1] != NULL) {
-        if (strchr (strv[1], '.') || strchr (strv[1], 'e') ||
-            strchr (strv[1], 'E')) {
-          filter->data_arithmetic.value.type = ARITH_OPRND_TYPE_DOUBLE;
-          filter->data_arithmetic.value.value_double =
-              g_ascii_strtod (strv[1], &not_consumed);
-        } else {
-          filter->data_arithmetic.value.type = ARITH_OPRND_TYPE_INT64;
-          filter->data_arithmetic.value.value_int64 =
-              g_ascii_strtoll (strv[1], &not_consumed, 10);
-        }
+        gchar **operands = g_strsplit (strv[1], ":", 2);
+        gchar *not_consumed;
+        int i;
+
+        for (i = 0; i < ARITH_OPRND_NUM_LIMIT; ++i) {
+          filter->data_arithmetic.value[i].type = ARITH_OPRND_TYPE_END;
+          if ((operands[i] != NULL) && (strlen (operands[i]) != 0)) {
+            if (strchr (operands[i], '.') || strchr (operands[i], 'e') ||
+                strchr (operands[i], 'E')) {
+              filter->data_arithmetic.value[i].type = ARITH_OPRND_TYPE_DOUBLE;
+              filter->data_arithmetic.value[i].value_double =
+                  g_ascii_strtod (operands[i], &not_consumed);
+            } else {
+              filter->data_arithmetic.value[i].type = ARITH_OPRND_TYPE_INT64;
+              filter->data_arithmetic.value[i].value_int64 =
+                  g_ascii_strtoll (operands[i], &not_consumed, 10);
+            }
 
-        if (strlen (not_consumed)) {
-          g_printerr ("%s is not a valid integer or floating point value\n",
-              strv[1]);
-          g_assert (0);
+            if (strlen (not_consumed)) {
+              g_printerr ("%s is not a valid integer or floating point value\n",
+                  operands[i]);
+              g_assert (0);
+            }
+          }
         }
+
+        g_strfreev (operands);
       }
 
       filter->loaded = TRUE;
@@ -633,7 +643,7 @@ gst_tensor_transform_typecast (GstTensor_Transform * filter,
 }
 
 /**
- * Macro to run loop for various data types with simple arithmetic
+ * Macro to run loop for various data types with simple arithmetic which has single operand
  */
 #define arith(itype, num, op, a) do { \
     size_t i; \
@@ -644,6 +654,59 @@ gst_tensor_transform_typecast (GstTensor_Transform * filter,
     } \
   }while(0);
 
+/**
+ * Macro to run loop for various data types with simple arithmetic which has dual operands
+ */
+#define arith2(itype, num, op1, a, op2, b) do { \
+    size_t i; \
+    itype *in = (itype *) inptr; \
+    itype *out = (itype *) outptr; \
+    for (i=0;i<num;i++){ \
+      *(out+i) = (*(in+i) op1 a) op2 b; \
+    } \
+  }while(0);
+
+/**
+ * Macro to handle the case of single operand
+ */
+#define arithmode_single_oprnd_case(itype, num, mode, op, value) do {\
+  itype a;\
+  switch (value[0].type) {\
+  case ARITH_OPRND_TYPE_INT64 : a = (itype) value[0].value_int64; break; \
+  case ARITH_OPRND_TYPE_DOUBLE : a = (itype) value[0].value_double; break;\
+  default: \
+  g_printerr ("The operand required by \'%s\' is not properly provided.\n", \
+      gst_tensor_transform_arithmetic_string[filter->data_arithmetic.mode]);\
+  g_assert(0); \
+  }; \
+  arith(itype, num, op, a); break; \
+} while (0); \
+
+/**
+ * Macro to handle the case of dual operands
+ */
+#define arithmode_dual_oprnd_case(itype, num, mode, op1, op2, value) \
+do {\
+  itype a;\
+  itype b; \
+  switch (value[0].type) {\
+  case ARITH_OPRND_TYPE_INT64 : a = (itype) value[0].value_int64; break; \
+  case ARITH_OPRND_TYPE_DOUBLE : a = (itype) value[0].value_double; break;\
+  default: \
+  g_printerr ("The operands required by \'%s\' are not properly provided.\n", \
+      gst_tensor_transform_arithmetic_string[filter->data_arithmetic.mode]);\
+  g_assert(0); \
+  }; \
+  switch (value[1].type) {\
+  case ARITH_OPRND_TYPE_INT64 : b = (itype) value[1].value_int64; break; \
+  case ARITH_OPRND_TYPE_DOUBLE : b = (itype) value[1].value_double; break;\
+  default: \
+  g_printerr ("The operands required by \'%s\' are not properly provided.\n", \
+      gst_tensor_transform_arithmetic_string[filter->data_arithmetic.mode]);\
+  g_assert(0); \
+  }; \
+  arith2(itype, num, op1, a, op2, b); break; \
+} while (0); \
 
 /**
  * Macro to run loop for various data types with simple arithmetic
@@ -651,19 +714,27 @@ gst_tensor_transform_typecast (GstTensor_Transform * filter,
 #define arithloopcase(typecase, itype, num, mode, value) \
   case typecase: \
   { \
-    itype a; \
-    switch (value.type) {\
-    case ARITH_OPRND_TYPE_INT64 : a = (itype) value.value_int64; break; \
-    case ARITH_OPRND_TYPE_DOUBLE : a = (itype) value.value_double; break;\
-    default: g_assert(0); \
-    }; \
     switch (mode) { \
-    case ARITH_ADD : arith(itype, num, +, a); break; \
-    case ARITH_MUL : arith(itype, num, *, a); break; \
+    case ARITH_ADD: {\
+      arithmode_single_oprnd_case (itype, num, mode, +, value); \
+      break; \
+    }; \
+    case ARITH_MUL: { \
+      arithmode_single_oprnd_case (itype, num, mode, *, value); \
+      break; \
+    };\
+    case ARITH_ADD_MUL: {\
+      arithmode_dual_oprnd_case (itype, num, mode, +, *, value); \
+      break; \
+    }; \
+    case ARITH_MUL_ADD: {\
+      arithmode_dual_oprnd_case (itype, num, mode, *, +, value); \
+      break; \
+    }; \
     default: g_assert(0); \
-    };                    \
-  };                      \
-  break; \
+    } \
+    break; \
+  }; \
 
 /**
  * @brief subrouting for tensor-tranform, "arithmetic" case.
@@ -678,7 +749,7 @@ gst_tensor_transform_arithmetic (GstTensor_Transform * filter,
 {
   uint32_t num = get_tensor_element_count (filter->fromDim);
   tensor_transform_arith_mode mode = filter->data_arithmetic.mode;
-  tensor_transform_arithmetic_operand value = filter->data_arithmetic.value;
+  tensor_transform_arithmetic_operand *value = filter->data_arithmetic.value;
 
   switch (filter->type) {
       arithloopcase (_NNS_INT8, int8_t, num, mode, value);
index 8fd964e..0620cfa 100644 (file)
@@ -60,6 +60,7 @@ G_BEGIN_DECLS
 #define GST_IS_TENSOR_TRANSFORM_CLASS(klass) \
   (G_TYPE_CHECK_CLASS_TYPE((klass),GST_TYPE_TENSOR_TRANSFORM))
 #define GST_TENSOR_TRANSFORM_CAST(obj)  ((GstTensor_Transform *)(obj))
+#define ARITH_OPRND_NUM_LIMIT  2
 
 typedef struct _GstTensor_Transform GstTensor_Transform;
 
@@ -80,6 +81,9 @@ typedef enum
 {
   ARITH_ADD = 0,
   ARITH_MUL = 1,
+  ARITH_ADD_MUL = 2,            /* Fused add-multiply */
+  ARITH_MUL_ADD = 3,            /* Fused multiply-add */
+
   ARITH_END,
 } tensor_transform_arith_mode;
 
@@ -128,7 +132,7 @@ typedef struct _tensor_transform_arithmetic_operand {
  */
 typedef struct _tensor_transform_arithmetic {
   tensor_transform_arith_mode mode;
-  tensor_transform_arithmetic_operand value;
+  tensor_transform_arithmetic_operand value[ARITH_OPRND_NUM_LIMIT];
 } tensor_transform_arithmetic;
 
 /**