From: Wook Song Date: Thu, 11 Oct 2018 06:53:55 +0000 (+0900) Subject: [Transform/Arithmetic] Support fused add-mul and mul-add operators X-Git-Tag: v0.0.2~14 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d6b92b91527a5ac35103d70760b6e69cfa079911;p=platform%2Fupstream%2Fnnstreamer.git [Transform/Arithmetic] Support fused add-mul and mul-add operators This patch adds two operators, fused add-mul and fused mul-add, to the Arithmetic mode of the transform plugin. Signed-off-by: Wook Song --- diff --git a/gst/tensor_transform/tensor_transform.c b/gst/tensor_transform/tensor_transform.c index 196d524..9a5f48e 100644 --- a/gst/tensor_transform/tensor_transform.c +++ b/gst/tensor_transform/tensor_transform.c @@ -197,11 +197,11 @@ static const gchar *gst_tensor_transform_mode_string[] = { [GTT_END] = "error", }; - /*TODO*/ -/* [ARITH_MAD] = "mad", (pixel[] + a) * b should be supported */ static const gchar *gst_tensor_transform_arithmetic_string[] = { [ARITH_ADD] = "add", [ARITH_MUL] = "mul", + [ARITH_ADD_MUL] = "add-mul", + [ARITH_MUL_ADD] = "mul-add", [ARITH_END] = "error", }; @@ -288,7 +288,6 @@ gst_tensor_transform_set_option_data (GstTensor_Transform * filter) case GTT_ARITHMETIC: { gchar **strv = g_strsplit (filter->option, ":", 2); - gchar *not_consumed; if (strv[0] != NULL) { filter->data_arithmetic.mode = @@ -296,22 +295,33 @@ gst_tensor_transform_set_option_data (GstTensor_Transform * filter) } if (strv[1] != NULL) { - if (strchr (strv[1], '.') || strchr (strv[1], 'e') || - strchr (strv[1], 'E')) { - filter->data_arithmetic.value.type = ARITH_OPRND_TYPE_DOUBLE; - filter->data_arithmetic.value.value_double = - g_ascii_strtod (strv[1], ¬_consumed); - } else { - filter->data_arithmetic.value.type = ARITH_OPRND_TYPE_INT64; - filter->data_arithmetic.value.value_int64 = - g_ascii_strtoll (strv[1], ¬_consumed, 10); - } + gchar **operands = g_strsplit (strv[1], ":", 2); + gchar *not_consumed; + int i; + + for (i = 0; i < ARITH_OPRND_NUM_LIMIT; ++i) { + filter->data_arithmetic.value[i].type = ARITH_OPRND_TYPE_END; + if ((operands[i] != NULL) && (strlen (operands[i]) != 0)) { + if (strchr (operands[i], '.') || strchr (operands[i], 'e') || + strchr (operands[i], 'E')) { + filter->data_arithmetic.value[i].type = ARITH_OPRND_TYPE_DOUBLE; + filter->data_arithmetic.value[i].value_double = + g_ascii_strtod (operands[i], ¬_consumed); + } else { + filter->data_arithmetic.value[i].type = ARITH_OPRND_TYPE_INT64; + filter->data_arithmetic.value[i].value_int64 = + g_ascii_strtoll (operands[i], ¬_consumed, 10); + } - if (strlen (not_consumed)) { - g_printerr ("%s is not a valid integer or floating point value\n", - strv[1]); - g_assert (0); + if (strlen (not_consumed)) { + g_printerr ("%s is not a valid integer or floating point value\n", + operands[i]); + g_assert (0); + } + } } + + g_strfreev (operands); } filter->loaded = TRUE; @@ -633,7 +643,7 @@ gst_tensor_transform_typecast (GstTensor_Transform * filter, } /** - * Macro to run loop for various data types with simple arithmetic + * Macro to run loop for various data types with simple arithmetic which has single operand */ #define arith(itype, num, op, a) do { \ size_t i; \ @@ -644,6 +654,59 @@ gst_tensor_transform_typecast (GstTensor_Transform * filter, } \ }while(0); +/** + * Macro to run loop for various data types with simple arithmetic which has dual operands + */ +#define arith2(itype, num, op1, a, op2, b) do { \ + size_t i; \ + itype *in = (itype *) inptr; \ + itype *out = (itype *) outptr; \ + for (i=0;idata_arithmetic.mode]);\ + g_assert(0); \ + }; \ + arith(itype, num, op, a); break; \ +} while (0); \ + +/** + * Macro to handle the case of dual operands + */ +#define arithmode_dual_oprnd_case(itype, num, mode, op1, op2, value) \ +do {\ + itype a;\ + itype b; \ + switch (value[0].type) {\ + case ARITH_OPRND_TYPE_INT64 : a = (itype) value[0].value_int64; break; \ + case ARITH_OPRND_TYPE_DOUBLE : a = (itype) value[0].value_double; break;\ + default: \ + g_printerr ("The operands required by \'%s\' are not properly provided.\n", \ + gst_tensor_transform_arithmetic_string[filter->data_arithmetic.mode]);\ + g_assert(0); \ + }; \ + switch (value[1].type) {\ + case ARITH_OPRND_TYPE_INT64 : b = (itype) value[1].value_int64; break; \ + case ARITH_OPRND_TYPE_DOUBLE : b = (itype) value[1].value_double; break;\ + default: \ + g_printerr ("The operands required by \'%s\' are not properly provided.\n", \ + gst_tensor_transform_arithmetic_string[filter->data_arithmetic.mode]);\ + g_assert(0); \ + }; \ + arith2(itype, num, op1, a, op2, b); break; \ +} while (0); \ /** * Macro to run loop for various data types with simple arithmetic @@ -651,19 +714,27 @@ gst_tensor_transform_typecast (GstTensor_Transform * filter, #define arithloopcase(typecase, itype, num, mode, value) \ case typecase: \ { \ - itype a; \ - switch (value.type) {\ - case ARITH_OPRND_TYPE_INT64 : a = (itype) value.value_int64; break; \ - case ARITH_OPRND_TYPE_DOUBLE : a = (itype) value.value_double; break;\ - default: g_assert(0); \ - }; \ switch (mode) { \ - case ARITH_ADD : arith(itype, num, +, a); break; \ - case ARITH_MUL : arith(itype, num, *, a); break; \ + case ARITH_ADD: {\ + arithmode_single_oprnd_case (itype, num, mode, +, value); \ + break; \ + }; \ + case ARITH_MUL: { \ + arithmode_single_oprnd_case (itype, num, mode, *, value); \ + break; \ + };\ + case ARITH_ADD_MUL: {\ + arithmode_dual_oprnd_case (itype, num, mode, +, *, value); \ + break; \ + }; \ + case ARITH_MUL_ADD: {\ + arithmode_dual_oprnd_case (itype, num, mode, *, +, value); \ + break; \ + }; \ default: g_assert(0); \ - }; \ - }; \ - break; \ + } \ + break; \ + }; \ /** * @brief subrouting for tensor-tranform, "arithmetic" case. @@ -678,7 +749,7 @@ gst_tensor_transform_arithmetic (GstTensor_Transform * filter, { uint32_t num = get_tensor_element_count (filter->fromDim); tensor_transform_arith_mode mode = filter->data_arithmetic.mode; - tensor_transform_arithmetic_operand value = filter->data_arithmetic.value; + tensor_transform_arithmetic_operand *value = filter->data_arithmetic.value; switch (filter->type) { arithloopcase (_NNS_INT8, int8_t, num, mode, value); diff --git a/gst/tensor_transform/tensor_transform.h b/gst/tensor_transform/tensor_transform.h index 8fd964e..0620cfa 100644 --- a/gst/tensor_transform/tensor_transform.h +++ b/gst/tensor_transform/tensor_transform.h @@ -60,6 +60,7 @@ G_BEGIN_DECLS #define GST_IS_TENSOR_TRANSFORM_CLASS(klass) \ (G_TYPE_CHECK_CLASS_TYPE((klass),GST_TYPE_TENSOR_TRANSFORM)) #define GST_TENSOR_TRANSFORM_CAST(obj) ((GstTensor_Transform *)(obj)) +#define ARITH_OPRND_NUM_LIMIT 2 typedef struct _GstTensor_Transform GstTensor_Transform; @@ -80,6 +81,9 @@ typedef enum { ARITH_ADD = 0, ARITH_MUL = 1, + ARITH_ADD_MUL = 2, /* Fused add-multiply */ + ARITH_MUL_ADD = 3, /* Fused multiply-add */ + ARITH_END, } tensor_transform_arith_mode; @@ -128,7 +132,7 @@ typedef struct _tensor_transform_arithmetic_operand { */ typedef struct _tensor_transform_arithmetic { tensor_transform_arith_mode mode; - tensor_transform_arithmetic_operand value; + tensor_transform_arithmetic_operand value[ARITH_OPRND_NUM_LIMIT]; } tensor_transform_arithmetic; /**