From 5ea6ec6c4e8cafb690f3a7b33fe77b4da8b11354 Mon Sep 17 00:00:00 2001 From: Jaeyun Date: Wed, 14 Nov 2018 18:31:18 +0900 Subject: [PATCH] [Transform] refactor arith/typecast refactor arith/typecast to handle arithmetic in sequence 1. add operands and handle tensor element with parsed option 2. add common functions for the operators 3. remove multi-op (add-mul and mul-add) and handle operator in sequence Signed-off-by: Jaeyun Jung --- gst/tensor_transform/tensor_transform.c | 797 ++++++++++++++++++++------------ gst/tensor_transform/tensor_transform.h | 48 +- tests/transform_arithmetic/runTest.sh | 38 +- 3 files changed, 530 insertions(+), 353 deletions(-) diff --git a/gst/tensor_transform/tensor_transform.c b/gst/tensor_transform/tensor_transform.c index 22fe903..4150d41 100644 --- a/gst/tensor_transform/tensor_transform.c +++ b/gst/tensor_transform/tensor_transform.c @@ -109,12 +109,12 @@ static const gchar *gst_tensor_transform_stand_string[] = { [STAND_END] = "error" }; -static const gchar *gst_tensor_transform_arithmetic_string[] = { - [ARITH_ADD] = "add", - [ARITH_MUL] = "mul", - [ARITH_ADD_MUL] = "add-mul", - [ARITH_MUL_ADD] = "mul-add", - [ARITH_END] = "error" +static const gchar *gst_tensor_transform_operator_string[] = { + [GTT_OP_TYPECAST] = "typecast", + [GTT_OP_ADD] = "add", + [GTT_OP_MUL] = "mul", + [GTT_OP_DIV] = "div", + [GTT_OP_UNKNOWN] = "unknown" }; /** @@ -226,24 +226,25 @@ gst_tensor_transform_init (GstTensorTransform * filter) filter->mode = GTT_END; filter->option = NULL; filter->loaded = FALSE; + filter->operators = NULL; gst_tensor_config_init (&filter->in_config); gst_tensor_config_init (&filter->out_config); } /** - * @brief Get the corresponding mode from the string value - * @param[in] str The string value for the mode - * @return corresponding mode for the string. ARITH_END for errors + * @brief Get the corresponding operator from the string value + * @param[in] str The string value for the operator + * @return corresponding operator for the string (GTT_OP_UNKNOWN for errors) */ -static tensor_transform_arith_mode -gst_tensor_transform_get_arith_mode (const gchar * str) +static tensor_transform_operator +gst_tensor_transform_get_operator (const gchar * str) { int index; - index = find_key_strv (gst_tensor_transform_arithmetic_string, str); + index = find_key_strv (gst_tensor_transform_operator_string, str); - return (index < 0) ? ARITH_END : index; + return (index < 0) ? GTT_OP_UNKNOWN : index; } /** @@ -277,6 +278,315 @@ gst_tensor_transform_get_mode (const gchar * str) } /** + * @brief Macro to set operand + */ +#define set_operand_value(v,d,vtype) do { \ + (v)->data._##vtype = *((vtype *) d); \ + } while (0) + +/** + * @brief Set tensor element value with given type + * @param filter "this" pointer + * @param value struct for operand of arith mode + * @param type tensor type + * @param data pointer of tensor element value + * @return TRUE if no error + */ +static gboolean +gst_tensor_transform_set_value (GstTensorTransform * filter, + tensor_transform_operand_s * value, tensor_type type, gpointer data) +{ + g_return_val_if_fail (value != NULL, FALSE); + g_return_val_if_fail (data != NULL, FALSE); + + /* init tensor value */ + memset (value, 0, sizeof (tensor_transform_operand_s)); + value->type = _NNS_END; + + switch (type) { + case _NNS_INT32: + set_operand_value (value, data, int32_t); + break; + case _NNS_UINT32: + set_operand_value (value, data, uint32_t); + break; + case _NNS_INT16: + set_operand_value (value, data, int16_t); + break; + case _NNS_UINT16: + set_operand_value (value, data, uint16_t); + break; + case _NNS_INT8: + set_operand_value (value, data, int8_t); + break; + case _NNS_UINT8: + set_operand_value (value, data, uint8_t); + break; + case _NNS_FLOAT64: + set_operand_value (value, data, double); + break; + case _NNS_FLOAT32: + set_operand_value (value, data, float); + break; + case _NNS_INT64: + set_operand_value (value, data, int64_t); + break; + case _NNS_UINT64: + set_operand_value (value, data, uint64_t); + break; + default: + GST_ERROR_OBJECT (filter, "Unknown tensor type %d", type); + return FALSE; + } + + value->type = type; + return TRUE; +} + +/** + * @brief Macro to get operand + */ +#define get_operand_value(v,d,vtype) do { \ + *((vtype *) d) = (v)->data._##vtype; \ + } while (0) + +/** + * @brief Get tensor element value with given type + * @param filter "this" pointer + * @param value struct for operand of arith mode + * @param data pointer of tensor element value + * @return TRUE if no error + */ +static gboolean +gst_tensor_transform_get_value (GstTensorTransform * filter, + tensor_transform_operand_s * value, gpointer data) +{ + g_return_val_if_fail (value != NULL, FALSE); + g_return_val_if_fail (data != NULL, FALSE); + + switch (value->type) { + case _NNS_INT32: + get_operand_value (value, data, int32_t); + break; + case _NNS_UINT32: + get_operand_value (value, data, uint32_t); + break; + case _NNS_INT16: + get_operand_value (value, data, int16_t); + break; + case _NNS_UINT16: + get_operand_value (value, data, uint16_t); + break; + case _NNS_INT8: + get_operand_value (value, data, int8_t); + break; + case _NNS_UINT8: + get_operand_value (value, data, uint8_t); + break; + case _NNS_FLOAT64: + get_operand_value (value, data, double); + break; + case _NNS_FLOAT32: + get_operand_value (value, data, float); + break; + case _NNS_INT64: + get_operand_value (value, data, int64_t); + break; + case _NNS_UINT64: + get_operand_value (value, data, uint64_t); + break; + default: + GST_ERROR_OBJECT (filter, "Unknown tensor type %d", value->type); + return FALSE; + } + + return TRUE; +} + +/** + * @brief Macro for operator + */ +#define handle_operator(d,v,oper,vtype) do { \ + switch (oper) { \ + case GTT_OP_ADD: \ + (d)->data._##vtype += (v)->data._##vtype; \ + break; \ + case GTT_OP_MUL: \ + (d)->data._##vtype *= (v)->data._##vtype; \ + break; \ + case GTT_OP_DIV: \ + if ((v)->data._##vtype == 0) { \ + GST_ERROR_OBJECT (filter, "Invalid state, denominator is 0."); \ + return FALSE; \ + } \ + (d)->data._##vtype /= (v)->data._##vtype; \ + break; \ + default: \ + GST_ERROR_OBJECT (filter, "Unknown operator %d", oper); \ + return FALSE; \ + } \ + } while (0) + +/** + * @brief Handle operators for tensor value + * @param filter "this" pointer + * @param desc struct for tensor value + * @param val struct for tensor value + * @param op operator for given tensor value + * @return TRUE if no error + */ +static gboolean +gst_tensor_transform_do_operator (GstTensorTransform * filter, + tensor_transform_operand_s * desc, const tensor_transform_operand_s * val, + tensor_transform_operator op) +{ + g_return_val_if_fail (desc != NULL, FALSE); + g_return_val_if_fail (val != NULL, FALSE); + g_return_val_if_fail (desc->type == val->type, FALSE); + + switch (desc->type) { + case _NNS_INT32: + handle_operator (desc, val, op, int32_t); + break; + case _NNS_UINT32: + handle_operator (desc, val, op, uint32_t); + break; + case _NNS_INT16: + handle_operator (desc, val, op, int16_t); + break; + case _NNS_UINT16: + handle_operator (desc, val, op, uint16_t); + break; + case _NNS_INT8: + handle_operator (desc, val, op, int8_t); + break; + case _NNS_UINT8: + handle_operator (desc, val, op, uint8_t); + break; + case _NNS_FLOAT64: + handle_operator (desc, val, op, double); + break; + case _NNS_FLOAT32: + handle_operator (desc, val, op, float); + break; + case _NNS_INT64: + handle_operator (desc, val, op, int64_t); + break; + case _NNS_UINT64: + handle_operator (desc, val, op, uint64_t); + break; + default: + GST_ERROR_OBJECT (filter, "Unknown tensor type %d", desc->type); + return FALSE; + } + + return TRUE; +} + +/** + * @brief Macro for typecast + */ +#define typecast_value_to(v,itype,otype) do { \ + itype in_val = (v)->data._##itype; \ + otype out_val = (otype) in_val; \ + (v)->data._##otype = out_val; \ + } while (0) + +#define typecast_value(v,otype) do { \ + switch ((v)->type) { \ + case _NNS_INT32: typecast_value_to (v, int32_t, otype); break; \ + case _NNS_UINT32: typecast_value_to (v, uint32_t, otype); break; \ + case _NNS_INT16: typecast_value_to (v, int16_t, otype); break; \ + case _NNS_UINT16: typecast_value_to (v, uint16_t, otype); break; \ + case _NNS_INT8: typecast_value_to (v, int8_t, otype); break; \ + case _NNS_UINT8: typecast_value_to (v, uint8_t, otype); break; \ + case _NNS_FLOAT64: typecast_value_to (v, double, otype); break; \ + case _NNS_FLOAT32: typecast_value_to (v, float, otype); break; \ + case _NNS_INT64: typecast_value_to (v, int64_t, otype); break; \ + case _NNS_UINT64: typecast_value_to (v, uint64_t, otype); break; \ + default: g_assert (0); break; \ + } \ + } while (0) + +/** + * @brief Typecast tensor element value + * @param filter "this" pointer + * @param value struct for operand of arith mode + * @param type tensor type to be transformed + * @return TRUE if no error + */ +static gboolean +gst_tensor_transform_typecast_value (GstTensorTransform * filter, + tensor_transform_operand_s * value, tensor_type type) +{ + gboolean is_float; + + g_return_val_if_fail (value != NULL, FALSE); + g_return_val_if_fail (type != _NNS_END, FALSE); + + /* do nothing when transform to same type */ + if (value->type != type) { + is_float = (type == _NNS_FLOAT32 || type == _NNS_FLOAT64); + + switch (type) { + case _NNS_INT32: + typecast_value (value, int32_t); + break; + case _NNS_UINT32: + if (is_float) { + /* int32 -> uint32 */ + typecast_value (value, int32_t); + } + typecast_value (value, uint32_t); + break; + case _NNS_INT16: + typecast_value (value, int16_t); + break; + case _NNS_UINT16: + if (is_float) { + /* int16 -> uint16 */ + typecast_value (value, int16_t); + } + typecast_value (value, uint16_t); + break; + case _NNS_INT8: + typecast_value (value, int8_t); + break; + case _NNS_UINT8: + if (is_float) { + /* int8 -> uint8 */ + typecast_value (value, int8_t); + } + typecast_value (value, uint8_t); + break; + case _NNS_FLOAT64: + typecast_value (value, double); + break; + case _NNS_FLOAT32: + typecast_value (value, float); + break; + case _NNS_INT64: + typecast_value (value, int64_t); + break; + case _NNS_UINT64: + if (is_float) { + /* int64 -> uint64 */ + typecast_value (value, int64_t); + } + typecast_value (value, uint64_t); + break; + default: + GST_ERROR_OBJECT (filter, "Unknown tensor type %d", type); + return FALSE; + } + + value->type = type; + } + + return TRUE; +} + +/** * @brief Setup internal data (data_* in GstTensorTransform) * @param[in/out] filter "this" pointer. mode & option MUST BE set already. */ @@ -317,46 +627,89 @@ gst_tensor_transform_set_option_data (GstTensorTransform * filter) } case GTT_ARITHMETIC: { - gchar **strv = g_strsplit (filter->option, ":", 2); - - if (strv[0] != NULL) { - filter->data_arithmetic.mode = - gst_tensor_transform_get_arith_mode (strv[0]); - g_assert (filter->data_arithmetic.mode != ARITH_END); - } - - if (strv[1] != NULL) { - gchar **operands = g_strsplit (strv[1], ":", 2); - gchar *not_consumed; - int i; - - for (i = 0; i < ARITH_OPRND_NUM_LIMIT; ++i) { - filter->data_arithmetic.value[i].type = ARITH_OPRND_TYPE_END; - if ((operands[i] != NULL) && (strlen (operands[i]) != 0)) { - if (strchr (operands[i], '.') || strchr (operands[i], 'e') || - strchr (operands[i], 'E')) { - filter->data_arithmetic.value[i].type = ARITH_OPRND_TYPE_DOUBLE; - filter->data_arithmetic.value[i].value_double = - g_ascii_strtod (operands[i], ¬_consumed); - } else { - filter->data_arithmetic.value[i].type = ARITH_OPRND_TYPE_INT64; - filter->data_arithmetic.value[i].value_int64 = - g_ascii_strtoll (operands[i], ¬_consumed, 10); - } + gchar **str_operators; + gchar **str_op; + tensor_transform_operator_s *op_s; + guint i, num_operators, num_op; + + filter->data_arithmetic.out_type = _NNS_END; + + str_operators = g_strsplit (filter->option, ",", -1); + num_operators = g_strv_length (str_operators); + + for (i = 0; i < num_operators; ++i) { + str_op = g_strsplit (str_operators[i], ":", -1); + num_op = g_strv_length (str_op); + + if (str_op[0]) { + op_s = g_new0 (tensor_transform_operator_s, 1); + g_assert (op_s); + + op_s->op = gst_tensor_transform_get_operator (str_op[0]); + + switch (op_s->op) { + case GTT_OP_TYPECAST: + if (num_op > 1 && str_op[1]) { + op_s->value.type = get_tensor_type (str_op[1]); + + if (op_s->value.type == _NNS_END) { + GST_WARNING_OBJECT (filter, "Unknown tensor type %s", + str_op[1]); + op_s->op = GTT_OP_UNKNOWN; + } else { + filter->data_arithmetic.out_type = op_s->value.type; + } + } else { + GST_WARNING_OBJECT (filter, "Invalid option for typecast %s", + str_operators[i]); + op_s->op = GTT_OP_UNKNOWN; + } + break; + case GTT_OP_ADD: + case GTT_OP_MUL: + case GTT_OP_DIV: + if (num_op > 1 && str_op[1]) { + /* get operand */ + if (strchr (str_op[1], '.') || strchr (str_op[1], 'e') || + strchr (str_op[1], 'E')) { + double val; + + val = g_ascii_strtod (str_op[1], NULL); + gst_tensor_transform_set_value (filter, &op_s->value, + _NNS_FLOAT64, &val); + } else { + int64_t val; + + val = g_ascii_strtoll (str_op[1], NULL, 10); + gst_tensor_transform_set_value (filter, &op_s->value, + _NNS_INT64, &val); + } + } else { + GST_WARNING_OBJECT (filter, "Invalid option for arithmetic %s", + str_operators[i]); + op_s->op = GTT_OP_UNKNOWN; + } + break; + default: + GST_WARNING_OBJECT (filter, "Unknown operator %s", str_op[0]); + break; + } - if (strlen (not_consumed)) { - g_printerr ("%s is not a valid integer or floating point value\n", - operands[i]); - g_assert (0); - } + /* append operator */ + if (op_s->op != GTT_OP_UNKNOWN) { + filter->operators = g_slist_append (filter->operators, op_s); + } else { + g_free (op_s); } + } else { + GST_WARNING_OBJECT (filter, "Invalid option %s", str_operators[i]); } - g_strfreev (operands); + g_strfreev (str_op); } - filter->loaded = TRUE; - g_strfreev (strv); + filter->loaded = (filter->operators != NULL); + g_strfreev (str_operators); break; } case GTT_TRANSPOSE: @@ -463,6 +816,11 @@ gst_tensor_transform_finalize (GObject * object) filter->option = NULL; } + if (filter->operators) { + g_slist_free_full (filter->operators, g_free); + filter->operators = NULL; + } + G_OBJECT_CLASS (parent_class)->finalize (object); } @@ -545,77 +903,6 @@ gst_tensor_transform_dimchg (GstTensorTransform * filter, } /** - * Macro to run loop for various data types with simple cast - */ -#define castloop(itype,otype,num) do { \ - otype *ptr = (otype *) outptr; \ - itype *iptr = (itype *) inptr; \ - size_t i; \ - for (i = 0; i < num; i++) { \ - *(ptr + i) = (otype) *(iptr + i); \ - } \ - } while (0) - -/** - * Macro to run loop for various data types with simple cast - * While castloop directly casts itype to otype, this macro indirectly casts - * itype to otype using mtype as an intermediate - */ -#define castloop_via_intermediate(itype, mtype, otype, num) do { \ - otype *ptr = (otype *) outptr; \ - itype *iptr = (itype *) inptr; \ - size_t i; \ - for (i = 0; i < num; i++) { \ - mtype m = (mtype) *(iptr + i);\ - *(ptr + i) = (otype) m; \ - } \ - } while (0) - -/** - * Macro to run loop for various data types with a converter function - */ -#define convloop(itype,otype,num,convfunc) do { \ - otype *ptr = (otype *) outptr; \ - itype *iptr = (itype *) inptr; \ - size_t i; \ - for (i = 0; i < num; i++) { \ - *(ptr + i) = convfunc(iptr + i); \ - } \ - } while (0) - -/** - * Macro to unburden switch cases with castloop/convloop (per itype) - * This is for cases otype is numeral. - */ -#define numotype_castloop_per_itype(otype,num) do { \ - switch (in_tensor_type) { \ - case _NNS_INT8: castloop(int8_t, otype, num); break; \ - case _NNS_INT16: castloop(int16_t, otype, num); break; \ - case _NNS_INT32: castloop(int32_t, otype, num); break; \ - case _NNS_UINT8: castloop(uint8_t, otype, num); break; \ - case _NNS_UINT16: castloop(uint16_t, otype, num); break; \ - case _NNS_UINT32: castloop(uint32_t, otype, num); break; \ - case _NNS_FLOAT32: castloop(float, otype, num); break; \ - case _NNS_FLOAT64: castloop(double, otype, num); break; \ - case _NNS_INT64: castloop(int64_t, otype, num); break; \ - case _NNS_UINT64: castloop(uint64_t, otype, num); break; \ - default: g_assert(0); return GST_FLOW_ERROR; \ - } \ - } while (0) - -#define numotype_castloop_via_intermediate_for_float_itype(mtype, otype, num) do { \ - switch (in_tensor_type) { \ - case _NNS_FLOAT32:\ - castloop_via_intermediate(float, mtype, otype, num); \ - break; \ - case _NNS_FLOAT64: \ - castloop_via_intermediate(double, mtype, otype, num); \ - break; \ - default: g_assert(0); \ - } \ - } while (0) - -/** * @brief subrouting for tensor-tranform, "typecast" case. * @param[in/out] filter "this" pointer * @param[in] inptr input tensor @@ -626,162 +913,32 @@ static GstFlowReturn gst_tensor_transform_typecast (GstTensorTransform * filter, const uint8_t * inptr, uint8_t * outptr) { - uint32_t num = get_tensor_element_count (filter->in_config.info.dimension); + size_t num = get_tensor_element_count (filter->in_config.info.dimension); tensor_type in_tensor_type = filter->in_config.info.type; + tensor_type out_tensor_type = filter->out_config.info.type; - switch (filter->data_typecast.to) { - case _NNS_INT8: - numotype_castloop_per_itype (int8_t, num); - break; - case _NNS_INT16: - numotype_castloop_per_itype (int16_t, num); - break; - case _NNS_INT32: - numotype_castloop_per_itype (int32_t, num); - break; - case _NNS_UINT8: - if ((in_tensor_type == _NNS_FLOAT32) || (in_tensor_type == _NNS_FLOAT64)) { - numotype_castloop_via_intermediate_for_float_itype (int8_t, uint8_t, - num); - } else { - numotype_castloop_per_itype (uint8_t, num); - } - break; - case _NNS_UINT16: - if ((in_tensor_type == _NNS_FLOAT32) || (in_tensor_type == _NNS_FLOAT64)) { - numotype_castloop_via_intermediate_for_float_itype (int16_t, uint16_t, - num); - } else { - numotype_castloop_per_itype (uint16_t, num); - } - break; - case _NNS_UINT32: - if ((in_tensor_type == _NNS_FLOAT32) || (in_tensor_type == _NNS_FLOAT64)) { - numotype_castloop_via_intermediate_for_float_itype (int32_t, uint32_t, - num); - } else { - numotype_castloop_per_itype (uint32_t, num); - } - break; - case _NNS_FLOAT32: - numotype_castloop_per_itype (float, num); - break; - case _NNS_FLOAT64: - numotype_castloop_per_itype (double, num); - break; - case _NNS_INT64: - numotype_castloop_per_itype (int64_t, num); - break; - case _NNS_UINT64: - if ((in_tensor_type == _NNS_FLOAT32) || (in_tensor_type == _NNS_FLOAT64)) { - numotype_castloop_via_intermediate_for_float_itype (int64_t, uint64_t, - num); - } else { - numotype_castloop_per_itype (uint64_t, num); - } - break; - default: - g_assert (0); - return GST_FLOW_ERROR; - } - - return GST_FLOW_OK; -} + tensor_transform_operand_s value; + size_t i, data_idx; -/** - * Macro to run loop for various data types with simple arithmetic which has single operand - */ -#define arith(itype,num,op,a) do { \ - size_t i; \ - itype *in = (itype *) inptr; \ - itype *out = (itype *) outptr; \ - for (i=0;idata_arithmetic.mode]);\ - g_assert(0); \ - }; \ - arith(itype, num, op, a); break; \ -} while (0); - -/** - * Macro to handle the case of dual operands - */ -#define arithmode_dual_oprnd_case(itype,num,mode,op1,op2,value) \ -do {\ - itype a;\ - itype b; \ - switch (value[0].type) {\ - case ARITH_OPRND_TYPE_INT64 : a = (itype) value[0].value_int64; break; \ - case ARITH_OPRND_TYPE_DOUBLE : a = (itype) value[0].value_double; break;\ - default: \ - g_printerr ("The operands required by \'%s\' are not properly provided.\n", \ - gst_tensor_transform_arithmetic_string[filter->data_arithmetic.mode]);\ - g_assert(0); \ - }; \ - switch (value[1].type) {\ - case ARITH_OPRND_TYPE_INT64 : b = (itype) value[1].value_int64; break; \ - case ARITH_OPRND_TYPE_DOUBLE : b = (itype) value[1].value_double; break;\ - default: \ - g_printerr ("The operands required by \'%s\' are not properly provided.\n", \ - gst_tensor_transform_arithmetic_string[filter->data_arithmetic.mode]);\ - g_assert(0); \ - }; \ - arith2(itype, num, op1, a, op2, b); break; \ -} while (0); + /* typecast */ + gst_tensor_transform_typecast_value (filter, &value, out_tensor_type); -/** - * Macro to run loop for various data types with simple arithmetic - */ -#define arithloopcase(typecase,itype,num,mode,value) \ - case typecase: \ - { \ - switch (mode) { \ - case ARITH_ADD: {\ - arithmode_single_oprnd_case (itype, num, mode, +, value); \ - break; \ - }; \ - case ARITH_MUL: { \ - arithmode_single_oprnd_case (itype, num, mode, *, value); \ - break; \ - };\ - case ARITH_ADD_MUL: {\ - arithmode_dual_oprnd_case (itype, num, mode, +, *, value); \ - break; \ - }; \ - case ARITH_MUL_ADD: {\ - arithmode_dual_oprnd_case (itype, num, mode, *, +, value); \ - break; \ - }; \ - default: g_assert(0); return GST_FLOW_ERROR; \ - } \ - break; \ + /* set output value */ + g_assert (out_tensor_type == value.type); + data_idx = tensor_element_size[out_tensor_type] * i; + gst_tensor_transform_get_value (filter, &value, + (gpointer) (outptr + data_idx)); } + return GST_FLOW_OK; +} + /** * @brief subrouting for tensor-tranform, "arithmetic" case. * @param[in/out] filter "this" pointer @@ -793,25 +950,64 @@ static GstFlowReturn gst_tensor_transform_arithmetic (GstTensorTransform * filter, const uint8_t * inptr, uint8_t * outptr) { - uint32_t num = get_tensor_element_count (filter->in_config.info.dimension); + size_t num = get_tensor_element_count (filter->in_config.info.dimension); tensor_type in_tensor_type = filter->in_config.info.type; - tensor_transform_arith_mode mode = filter->data_arithmetic.mode; - tensor_transform_arithmetic_operand *value = filter->data_arithmetic.value; - - switch (in_tensor_type) { - arithloopcase (_NNS_INT8, int8_t, num, mode, value); - arithloopcase (_NNS_INT16, int16_t, num, mode, value); - arithloopcase (_NNS_INT32, int32_t, num, mode, value); - arithloopcase (_NNS_UINT8, uint8_t, num, mode, value); - arithloopcase (_NNS_UINT16, uint16_t, num, mode, value); - arithloopcase (_NNS_UINT32, uint32_t, num, mode, value); - arithloopcase (_NNS_FLOAT32, float, num, mode, value); - arithloopcase (_NNS_FLOAT64, double, num, mode, value); - arithloopcase (_NNS_INT64, int64_t, num, mode, value); - arithloopcase (_NNS_UINT64, uint64_t, num, mode, value); - default: - g_assert (0); - return GST_FLOW_ERROR; + tensor_type out_tensor_type = filter->out_config.info.type; + + GSList *walk; + tensor_transform_operator_s *op_s; + tensor_transform_operand_s value; + size_t i, data_idx; + + for (i = 0; i < num; ++i) { + /* init value with input tensor type */ + data_idx = tensor_element_size[in_tensor_type] * i; + gst_tensor_transform_set_value (filter, &value, in_tensor_type, + (gpointer) (inptr + data_idx)); + + walk = filter->operators; + while (walk) { + op_s = (tensor_transform_operator_s *) walk->data; + + /** + * @todo add more options + */ + switch (op_s->op) { + case GTT_OP_TYPECAST: + gst_tensor_transform_typecast_value (filter, &value, + op_s->value.type); + break; + case GTT_OP_ADD: + gst_tensor_transform_typecast_value (filter, &op_s->value, + value.type); + gst_tensor_transform_do_operator (filter, &value, &op_s->value, + GTT_OP_ADD); + break; + case GTT_OP_MUL: + gst_tensor_transform_typecast_value (filter, &op_s->value, + value.type); + gst_tensor_transform_do_operator (filter, &value, &op_s->value, + GTT_OP_MUL); + break; + case GTT_OP_DIV: + gst_tensor_transform_typecast_value (filter, &op_s->value, + value.type); + gst_tensor_transform_do_operator (filter, &value, &op_s->value, + GTT_OP_DIV); + break; + default: + g_assert (0); + return GST_FLOW_ERROR; + } + + walk = g_slist_next (walk); + } + + /* set output value */ + g_assert (out_tensor_type == value.type); + data_idx = tensor_element_size[out_tensor_type] * i; + gst_tensor_transform_get_value (filter, &value, + (gpointer) (outptr + data_idx)); } return GST_FLOW_OK; @@ -970,6 +1166,7 @@ gst_tensor_transform_transform (GstBaseTransform * trans, uint8_t *inptr, *outptr; GstMapInfo inInfo, outInfo; + g_assert (filter->loaded); g_assert (gst_buffer_map (inbuf, &inInfo, GST_MAP_READ)); g_assert (gst_buffer_map (outbuf, &outInfo, GST_MAP_WRITE)); @@ -1041,12 +1238,13 @@ gst_tensor_transform_convert_dimension (GstTensorTransform * filter, GstPadDirection direction, const GstTensorInfo * in_info, GstTensorInfo * out_info) { + int i; + switch (filter->mode) { case GTT_DIMCHG: out_info->type = in_info->type; if (direction == GST_PAD_SINK) { - int i; int a = filter->data_dimchg.from; int b = filter->data_dimchg.to; @@ -1075,7 +1273,6 @@ gst_tensor_transform_convert_dimension (GstTensorTransform * filter, } } } else { - int i; int a = filter->data_dimchg.from; int b = filter->data_dimchg.to; @@ -1105,10 +1302,9 @@ gst_tensor_transform_convert_dimension (GstTensorTransform * filter, } } break; + case GTT_TYPECAST: - { - /** For both directions, dimension does not change */ - int i; + /** For both directions, dimension does not change */ for (i = 0; i < NNS_TENSOR_RANK_LIMIT; i++) { out_info->dimension[i] = in_info->dimension[i]; } @@ -1120,20 +1316,23 @@ gst_tensor_transform_convert_dimension (GstTensorTransform * filter, out_info->type = in_info->type; /** @todo this may cause problems with Cap-Transform */ } break; - } + case GTT_ARITHMETIC: - { - int i; for (i = 0; i < NNS_TENSOR_RANK_LIMIT; i++) { out_info->dimension[i] = in_info->dimension[i]; } out_info->type = in_info->type; + + /* check arith mode option has typecast operator */ + if (direction == GST_PAD_SINK && + filter->data_arithmetic.out_type != _NNS_END) { + out_info->type = filter->data_arithmetic.out_type; + } break; - } + case GTT_TRANSPOSE: - { out_info->type = in_info->type; - int i; + if (direction == GST_PAD_SINK) { for (i = 0; i < NNS_TENSOR_RANK_LIMIT; i++) { out_info->dimension[i] = @@ -1148,16 +1347,14 @@ gst_tensor_transform_convert_dimension (GstTensorTransform * filter, } } break; - } + case GTT_STAND: - { - int i; for (i = 0; i < NNS_TENSOR_RANK_LIMIT; i++) { out_info->dimension[i] = in_info->dimension[i]; } out_info->type = in_info->type; break; - } + default: return FALSE; } diff --git a/gst/tensor_transform/tensor_transform.h b/gst/tensor_transform/tensor_transform.h index 6377512..90c9fd9 100644 --- a/gst/tensor_transform/tensor_transform.h +++ b/gst/tensor_transform/tensor_transform.h @@ -46,8 +46,6 @@ G_BEGIN_DECLS (G_TYPE_CHECK_CLASS_TYPE((klass),GST_TYPE_TENSOR_TRANSFORM)) #define GST_TENSOR_TRANSFORM_CAST(obj) ((GstTensorTransform *)(obj)) -#define ARITH_OPRND_NUM_LIMIT 2 - typedef struct _GstTensorTransform GstTensorTransform; typedef struct _GstTensorTransformClass GstTensorTransformClass; @@ -64,13 +62,13 @@ typedef enum typedef enum { - ARITH_ADD = 0, - ARITH_MUL = 1, - ARITH_ADD_MUL = 2, /* Fused add-multiply */ - ARITH_MUL_ADD = 3, /* Fused multiply-add */ + GTT_OP_TYPECAST = 0, + GTT_OP_ADD = 1, + GTT_OP_MUL = 2, + GTT_OP_DIV = 3, - ARITH_END, -} tensor_transform_arith_mode; + GTT_OP_UNKNOWN +} tensor_transform_operator; typedef enum { @@ -78,14 +76,6 @@ typedef enum STAND_END, } tensor_transform_stand_mode; -typedef enum -{ - ARITH_OPRND_TYPE_INT64 = 0, - ARITH_OPRND_TYPE_DOUBLE = 1, - - ARITH_OPRND_TYPE_END -} tensor_transform_arith_oprnd_type; - /** * @brief Internal data structure for dimchg mode. */ @@ -104,20 +94,26 @@ typedef struct _tensor_transform_typecast { /** * @brief Internal data structure for operand of arithmetic mode. */ -typedef struct _tensor_transform_arithmetic_operand { - tensor_transform_arith_oprnd_type type; - union { - int64_t value_int64; - double value_double; - }; -} tensor_transform_arithmetic_operand; +typedef struct +{ + tensor_type type; + tensor_element data; +} tensor_transform_operand_s; + +/** + * @brief Internal data structure for operator of arithmetic mode. + */ +typedef struct +{ + tensor_transform_operator op; + tensor_transform_operand_s value; +} tensor_transform_operator_s; /** * @brief Internal data structure for arithmetic mode. */ typedef struct _tensor_transform_arithmetic { - tensor_transform_arith_mode mode; - tensor_transform_arithmetic_operand value[ARITH_OPRND_NUM_LIMIT]; + tensor_type out_type; } tensor_transform_arithmetic; /** @@ -153,6 +149,8 @@ struct _GstTensorTransform }; gboolean loaded; /**< TRUE if mode & option are loaded */ + GSList *operators; /**< operators list */ + GstTensorConfig in_config; /**< input tensor info */ GstTensorConfig out_config; /**< output tensor info */ }; diff --git a/tests/transform_arithmetic/runTest.sh b/tests/transform_arithmetic/runTest.sh index 5f0c52e..a28bce6 100644 --- a/tests/transform_arithmetic/runTest.sh +++ b/tests/transform_arithmetic/runTest.sh @@ -45,16 +45,16 @@ python checkResult.py arithmetic testcase02.direct.log testcase02.arithmetic.log testResult $? 2 "Golden test comparison" 0 1 # Test for mul with floating-point operand -gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float32 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=mul:-5.5 ! filesink location=\"testcase03.arithmetic.log\" sync=true t. ! queue ! filesink location=\"testcase03.direct.log\" sync=true" 3 0 0 $PERFORMANCE +gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float32 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=mul:-5.5 ! filesink location=\"testcase03.arithmetic.1.log\" sync=true t. ! queue ! filesink location=\"testcase03.direct.1.log\" sync=true" 3 0 0 $PERFORMANCE -python checkResult.py arithmetic testcase03.direct.log testcase03.arithmetic.log 4 4 f f mul -5.5 0 +python checkResult.py arithmetic testcase03.direct.1.log testcase03.arithmetic.1.log 4 4 f f mul -5.5 0 testResult $? 3 "Golden test comparison" 0 1 -# Fail Test 3-F: for mul with floating-point operand -gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float32 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=mul::-5.5 ! filesink location=\"testcase03.arithmetic.fail.log\" sync=true t. ! queue ! filesink location=\"testcase03.direct.fail.log\" sync=true" 3-F 0 1 $PERFORMANCE +# Test 3-2 for typecast,mul with floating-point operand +gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tee name=t ! queue ! tensor_transform mode=arithmetic option=typecast:float32,mul:-5.5 ! filesink location=\"testcase03.arithmetic.2.log\" sync=true t. ! queue ! tensor_transform mode=typecast option=float32 ! filesink location=\"testcase03.direct.2.log\" sync=true" 3-2 0 0 $PERFORMANCE -python checkResult.py arithmetic testcase03.direct.fail.log testcase03.arithmetic.fail.log 4 4 f f mul 0 -5.5 -testResult $? 3-F "Golden test comparison" 0 1 +python checkResult.py arithmetic testcase03.direct.2.log testcase03.arithmetic.2.log 4 4 f f mul -5.5 0 +testResult $? 3-2 "Golden test comparison" 0 1 # Test for add with floating-point operand gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=add:9.900000e-001 ! filesink location=\"testcase04.arithmetic.log\" sync=true t. ! queue ! filesink location=\"testcase04.direct.log\" sync=true" 4 0 0 $PERFORMANCE @@ -68,32 +68,14 @@ gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequenc python checkResult.py arithmetic testcase04.direct.ok.log testcase04.arithmetic.ok.log 8 8 d d add 9.900000e-001 -80.256 testResult $? 4-OK "Golden test comparison" 0 1 -# Test for add-mul with floating-point operands -gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=add-mul:-9.3:-11.4823e-002 ! filesink location=\"testcase05.arithmetic.log\" sync=true t. ! queue ! filesink location=\"testcase05.direct.log\" sync=true" 5 0 0 $PERFORMANCE +# Test for add,mul with floating-point operands +gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=add:-9.3,mul:-11.4823e-002 ! filesink location=\"testcase05.arithmetic.log\" sync=true t. ! queue ! filesink location=\"testcase05.direct.log\" sync=true" 5 0 0 $PERFORMANCE testResult $? 5 "Golden test comparison" 0 1 python checkResult.py arithmetic testcase05.direct.log testcase05.arithmetic.log 8 8 d d add-mul -9.3 -11.4823e-002 -# Fail Test 5-F1: add-mul with single floating-point operand -gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=add-mul:-9.3 ! filesink location=\"testcase05.arithmetic.fail1.log\" sync=true t. ! queue ! filesink location=\"testcase05.direct.fail1.log\" sync=true" 5-F1 0 1 $PERFORMANCE - -testResult $? 5-F1 "Golden test comparison" 0 1 -python checkResult.py arithmetic testcase05.direct.fail1.log testcase05.arithmetic.fail1.log 8 8 d d add-mul -9.3 0 - -# Fail Test 5-F2: add-mul with single floating-point operand -gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=add-mul:-9.3: ! filesink location=\"testcase05.arithmetic.fail2.log\" sync=true t. ! queue ! filesink location=\"testcase05.direct.fail2.log\" sync=true" 5-F2 0 1 $PERFORMANCE - -testResult $? 5-F2 "Golden test comparison" 0 1 -python checkResult.py arithmetic testcase05.direct.fail2.log testcase05.arithmetic.fail2.log 8 8 d d add-mul -9.3 0 - -# Fail Test 5-F3: add-mul with single floating-point operand -gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=add-mul::-11.4823e-002 ! filesink location=\"testcase05.arithmetic.fail3.log\" sync=true t. ! queue ! filesink location=\"testcase05.direct.fail3.log\" sync=true" 5-F3 0 1 $PERFORMANCE - -testResult $? 5-F3 "Golden test comparison" 0 1 -python checkResult.py arithmetic testcase05.direct.fail3.log testcase05.arithmetic.fail3.log 8 8 d d add-mul 30 -11.4823e-002 - -# Test for mul-add with floating-point operands -gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=mul-add:-50.0987e+003:15.3 ! filesink location=\"testcase06.arithmetic.log\" sync=true t. ! queue ! filesink location=\"testcase06.direct.log\" sync=true" 6 0 0 $PERFORMANCE +# Test for mul,add with floating-point operands +gstTest "--gst-plugin-path=${PATH_TO_PLUGIN} multifilesrc location=\"testsequence_%1d.png\" index=0 caps=\"image/png,framerate=\(fraction\)30/1\" ! pngdec ! videoconvert ! video/x-raw, format=RGB ! tensor_converter ! tensor_transform mode=typecast option=float64 ! tee name=t ! queue ! tensor_transform mode=arithmetic option=mul:-50.0987e+003,add:15.3 ! filesink location=\"testcase06.arithmetic.log\" sync=true t. ! queue ! filesink location=\"testcase06.direct.log\" sync=true" 6 0 0 $PERFORMANCE testResult $? 6 "Golden test comparison" 0 1 python checkResult.py arithmetic testcase06.direct.log testcase06.arithmetic.log 8 8 d d add-mul -50.0987e+003 15.3 -- 2.7.4