#define GST_CAT_DEFAULT gst_tensor_transform_debug
#define CAPS_STRING GST_TENSOR_CAP_DEFAULT ";" GST_TENSORS_CAP_MAKE ("{ static, flexible }")
#define REGEX_DIMCHG_OPTION "^([0-3]):([0-3])$"
-#define REGEX_TYPECAST_OPTION "(^[u]?int(8|16|32|64)$|^float(32|64)$)"
+#define REGEX_TYPECAST_OPTION "(^[u]?int(8|16|32|64)$|^float(16|32|64)$)" /* one type name: [u]int8..64 or float16/32/64 */
#define REGEX_TRANSPOSE_OPTION "^(?:([0-2]):(?!.*\\1)){3}3$"
-#define REGEX_STAND_OPTION "^(default|dc-average)(:([u]?int(8|16|32|64)|float(32|64)))?(,per-channel:(true|false))?$"
+#define REGEX_STAND_OPTION "^(default|dc-average)(:([u]?int(8|16|32|64)|float(16|32|64)))?(,per-channel:(true|false))?$" /* mode, then optional ":<type>" (float16 allowed), then optional ",per-channel:<bool>" */
#define REGEX_CLAMP_OPTION "^((([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?))):"\
"((([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)))$"
-#define REGEX_ARITH_OPTION "^(typecast:([u]?int(8|16|32|64)|float(32|64)),)?"\
+#define REGEX_ARITH_OPTION "^(typecast:([u]?int(8|16|32|64)|float(16|32|64)),)?" /* optional leading "typecast:<type>," where <type> may be float16 */ \
"(per-channel:(false|true@[0-9]+),)?"\
"(((add|mul|div)(:([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?))+(@[0-9]+)?)(,|))+$"
-#define REGEX_ARITH_OPTION_TYPECAST "(typecast:([u]?int(8|16|32|64)|float(32|64)))"
+#define REGEX_ARITH_OPTION_TYPECAST "(typecast:([u]?int(8|16|32|64)|float(16|32|64)))" /* un-anchored pattern for a "typecast:<type>" token, float16 included */
/**
* @brief tensor_transform properties
return (index < 0) ? STAND_END : index;
}
+#ifndef FLOAT16_SUPPORT
+/**
+ * @brief Generate error if float16 is required.
+ * @details Compiled only when FLOAT16_SUPPORT is not defined. Logs how to
+ *          enable float16 at build time and then trips g_assert (0).
+ * @note NOTE(review): when GLib assertions are compiled out
+ *       (G_DISABLE_ASSERT), this presumably returns to the caller after
+ *       logging; confirm that callers tolerate that.
+ */
+static void
+float16_not_supported (void)
+{
+  ml_loge
+      ("Tensor_transform does not support float16 operators. Apply -Denable-float16=true for meson build option if your architecture supports float16. Note that tensor-transform's float16 is adhoc and does NOT perform well (slow!).\n");
+  g_assert (0);
+}
+#endif
+
+#ifdef FLOAT16_SUPPORT
+/**
+ * @brief Emit a one-time warning when a float16 operation is large.
+ * @param n element count of the operation about to run
+ * @details Nothing happens for n up to 1 million, or once the warning has
+ *          already been printed.
+ * @todo Remove this after applying SIMD or ORC
+ */
+static void
+refrain_from_heavy_op_on_float16 (gulong n)
+{
+  static int already_warned = 0;
+
+  /* small enough (<= 1 million elements), or the user has been told before */
+  if (n <= 1000000 || already_warned)
+    return;
+
+  ml_logw
+      ("Tensor_transform implementation for float16 does not support SIMD. Heavy tensor-transform operations of float16 is not recommended. Try to apply heavy ops with other types (e.g., float32) and convert it to float16 at the time when it's really needed.\n");
+  already_warned = 1;
+}
+
+/**
+ * @brief Convert n elements of type intype into float16, one by one.
+ * @param intype C type name of the input elements (e.g., int32_t, float)
+ * @param o output buffer, accessed as float16 *
+ * @param i input buffer, accessed as intype *
+ * @param n number of elements; it is expanded more than once, so pass an
+ *          expression without side effects
+ * @note The expansion declares locals named op, ip and idx.
+ * @todo Make this use SIMD or ORC
+ */
+#define _conv_to_f16(intype, o, i, n) \
+  do { \
+    float16 *op = (gpointer) (o); \
+    intype *ip = (gpointer) (i); \
+    gulong idx; \
+    refrain_from_heavy_op_on_float16 (n); \
+    /* (n) is parenthesized so a compound argument keeps its meaning */ \
+    for (idx = 0; idx < (n); idx++) \
+      *(op + idx) = (float16) *(ip + idx); \
+  } while (0)
+
+/**
+ * @brief Element-wise copy loop that casts each input element to otypename.
+ * @param n number of elements; expanded on every loop test
+ * @param op pointer to the output elements
+ * @param ip pointer to the input elements
+ * @param otypename C type name each value is cast to before it is stored
+ * @note The expansion declares a local named idx.
+ * @todo Make this use SIMD or ORC
+ */
+#define _conv_from_f16_action(n, op, ip, otypename) \
+  do { \
+    gulong idx; \
+    /* every macro argument is parenthesized to survive compound expressions */ \
+    for (idx = 0; idx < (n); idx++) \
+      *((op) + idx) = (otypename) *((ip) + idx); \
+  } while (0)
+
+/**
+ * @brief Expand to one switch case of _conv_from_f16.
+ * @param typeenum case label (tensor type enumerator)
+ * @param otypename C type name of the output elements for that label
+ * @param o output buffer
+ * @param ip float16 input pointer
+ * @param n number of elements
+ */
+#define _conv_from_f16_case(typeenum, otypename, o, ip, n) \
+  case typeenum: { \
+    otypename *op = (gpointer) (o); \
+    _conv_from_f16_action (n, op, ip, otypename); \
+    break; }
+
+/**
+ * @brief Convert n float16 elements at i into the type selected by otype.
+ * @param otype tensor type enumerator of the output
+ * @param o output buffer
+ * @param i input buffer, accessed as float16 *
+ * @param n number of elements
+ * @note The default branch uses a variable named filter from the expansion
+ *       site for logging, then trips g_assert (0).
+ * @todo Make this use SIMD or ORC
+ */
+#define _conv_from_f16(otype, o, i, n) \
+  do { \
+    float16 *ip = (gpointer) (i); \
+    refrain_from_heavy_op_on_float16 (n); \
+    switch (otype) { \
+      _conv_from_f16_case (_NNS_INT32, int32_t, o, ip, n) \
+      _conv_from_f16_case (_NNS_UINT32, uint32_t, o, ip, n) \
+      _conv_from_f16_case (_NNS_INT16, int16_t, o, ip, n) \
+      _conv_from_f16_case (_NNS_UINT16, uint16_t, o, ip, n) \
+      _conv_from_f16_case (_NNS_INT8, int8_t, o, ip, n) \
+      _conv_from_f16_case (_NNS_UINT8, uint8_t, o, ip, n) \
+      _conv_from_f16_case (_NNS_FLOAT64, double, o, ip, n) \
+      _conv_from_f16_case (_NNS_FLOAT32, float, o, ip, n) \
+      _conv_from_f16_case (_NNS_FLOAT16, float16, o, ip, n) \
+      default: GST_ERROR_OBJECT (filter, "Unsupported type %d", (otype)); g_assert (0); \
+    } \
+  } while (0)
+
+/**
+ * @brief Apply add, mul or div with a constant to n float16 elements in place.
+ * @param i buffer holding the elements, accessed as float16 *; overwritten
+ * @param n number of elements; expanded on every loop test
+ * @param v operand applied to each element; expanded on every iteration
+ * @param op operator id (GTT_OP_ADD, GTT_OP_MUL or GTT_OP_DIV)
+ * @note Any other operator id is only logged, through a variable named filter
+ *       taken from the expansion site; the data is left untouched.
+ * @todo Make this use SIMD or ORC
+ */
+#define _op_float16(i, n, v, op) \
+  do { \
+    gulong idx; \
+    float16 *data_in = (float16 *) (i); \
+    refrain_from_heavy_op_on_float16 (n); \
+    /* (n) and (op) are parenthesized so compound arguments keep their meaning */ \
+    switch (op) { \
+      case GTT_OP_ADD: \
+        for (idx = 0; idx < (n); idx++) \
+          data_in[idx] = data_in[idx] + (v); \
+        break; \
+      case GTT_OP_MUL: \
+        for (idx = 0; idx < (n); idx++) \
+          data_in[idx] = data_in[idx] * (v); \
+        break; \
+      case GTT_OP_DIV: \
+        for (idx = 0; idx < (n); idx++) \
+          data_in[idx] = data_in[idx] / (v); \
+        break; \
+      default: GST_ERROR_OBJECT (filter, "Unknown operator for float16: %d", (op)); break; \
+    } \
+  } while (0)
+
+#else /* ! FLOAT16_SUPPORT: each float16 helper below ignores its arguments and only calls float16_not_supported () */
+#define _conv_to_f16(intype, o, i, n) do { float16_not_supported (); } while (0)
+#define _conv_from_f16(otype, o, i, n) do { float16_not_supported (); } while (0)
+#define _op_float16(i, n, v, op) do { float16_not_supported (); } while (0)
+#endif /* FLOAT16_SUPPORT */
+
#ifdef HAVE_ORC
/* define macros for orc */
/** @todo support 64bit integer and remove below line */
#define orc_func_mul(intype) nns_orc_mul_c_ ## intype
#define orc_func_div(intype) nns_orc_div_c_ ## intype
-#define orc_typecast_to(i,o,n,intype,otype) do { \
+#define orc_typecast_to(i,o,n,intype,otype,intypename) do { /* intypename: C type name of the input; only the float16 case uses it */ \
switch (otype) { \
case _NNS_INT32: orc_func_conv (intype, s32) ((gpointer) o, (gpointer) i, n); break; \
case _NNS_UINT32: orc_func_conv (intype, u32) ((gpointer) o, (gpointer) i, n); break; \
case _NNS_UINT8: orc_func_conv (intype, u8) ((gpointer) o, (gpointer) i, n); break; \
case _NNS_FLOAT64: orc_func_conv (intype, f64) ((gpointer) o, (gpointer) i, n); break; \
case _NNS_FLOAT32: orc_func_conv (intype, f32) ((gpointer) o, (gpointer) i, n); break; \
+ case _NNS_FLOAT16: _conv_to_f16 (intypename, o, i, n); break; /* float16 output goes through _conv_to_f16 instead of orc_func_conv */ \
default: GST_ERROR_OBJECT (filter, "Unsupported output type %d", otype); g_assert (0); break; \
} \
} while (0)
#define orc_typecast(i,o,n,itype,otype) do { \
switch (itype) { \
- case _NNS_INT32: orc_typecast_to (i, o, n, s32, otype); break; \
- case _NNS_UINT32: orc_typecast_to (i, o, n, u32, otype); break; \
- case _NNS_INT16: orc_typecast_to (i, o, n, s16, otype); break; \
- case _NNS_UINT16: orc_typecast_to (i, o, n, u16, otype); break; \
- case _NNS_INT8: orc_typecast_to (i, o, n, s8, otype); break; \
- case _NNS_UINT8: orc_typecast_to (i, o, n, u8, otype); break; \
- case _NNS_FLOAT64: orc_typecast_to (i, o, n, f64, otype); break; \
- case _NNS_FLOAT32: orc_typecast_to (i, o, n, f32, otype); break; \
+ case _NNS_INT32: orc_typecast_to (i, o, n, s32, otype, int32_t); break; \
+ case _NNS_UINT32: orc_typecast_to (i, o, n, u32, otype, uint32_t); break; \
+ case _NNS_INT16: orc_typecast_to (i, o, n, s16, otype, int16_t); break; \
+ case _NNS_UINT16: orc_typecast_to (i, o, n, u16, otype, uint16_t); break; \
+ case _NNS_INT8: orc_typecast_to (i, o, n, s8, otype, int8_t); break; \
+ case _NNS_UINT8: orc_typecast_to (i, o, n, u8, otype, uint8_t); break; \
+ case _NNS_FLOAT64: orc_typecast_to (i, o, n, f64, otype, double); break; \
+ case _NNS_FLOAT32: orc_typecast_to (i, o, n, f32, otype, float); break; \
+ case _NNS_FLOAT16: _conv_from_f16 (otype, o, i, n); break; /* float16 input bypasses orc_typecast_to; _conv_from_f16 selects the output type itself */ \
default: GST_ERROR_OBJECT (filter, "Unsupported input type %d", itype); g_assert (0); break; \
} \
} while (0)
-#define orc_operator_func(i,n,v,opfunc) do { \
+#define orc_operator_func(i,n,v,opfunc,op) do { /* op: operator id, consumed only by the float16 case */ \
switch ((v)->type) { \
case _NNS_INT32: opfunc (s32) ((gpointer) i, (v)->data._int32_t, n); break; \
case _NNS_UINT32: opfunc (u32) ((gpointer) i, (v)->data._uint32_t, n); break; \
case _NNS_UINT8: opfunc (u8) ((gpointer) i, (v)->data._uint8_t, n); break; \
case _NNS_FLOAT64: opfunc (f64) ((gpointer) i, (v)->data._double, n); break; \
case _NNS_FLOAT32: opfunc (f32) ((gpointer) i, (v)->data._float, n); break; \
+ case _NNS_FLOAT16: _op_float16 (i, n, (v)->data._float16, op); break; /* float16 ignores opfunc and dispatches on op inside _op_float16 */ \
default: GST_ERROR_OBJECT (filter, "Unsupported type %d", (v)->type); g_assert (0); break; \
} \
} while (0)
#define orc_operator(i,n,v,op) do { \
switch (op) { \
- case GTT_OP_ADD: orc_operator_func (i, n, v, orc_func_add); break; \
- case GTT_OP_MUL: orc_operator_func (i, n, v, orc_func_mul); break; \
+ case GTT_OP_ADD: orc_operator_func (i, n, v, orc_func_add, op); break; \
+ case GTT_OP_MUL: orc_operator_func (i, n, v, orc_func_mul, op); break; \
case GTT_OP_DIV: \
switch ((v)->type) { \
case _NNS_INT32: orc_operator_div_loop (i, n, (v)->data._int32_t, int32_t); break; \
case _NNS_UINT8: orc_operator_div_loop (i, n, (v)->data._uint8_t, uint8_t); break; \
case _NNS_FLOAT64: orc_func_div (f64) ((gpointer) i, (v)->data._double, n); break; \
case _NNS_FLOAT32: orc_func_div (f32) ((gpointer) i, (v)->data._float, n); break; \
+ case _NNS_FLOAT16: _op_float16 (i, n, (v)->data._float16, op); break; \
default: GST_ERROR_OBJECT (filter, "Unsupported type %d", (v)->type); g_assert (0); break; \
} \
break; \
case _NNS_FLOAT32:
handle_operator (desc, val, op, float);
break;
+ case _NNS_FLOAT16:
+#ifdef FLOAT16_SUPPORT
+ handle_operator (desc, val, op, float16);
+#else
+ float16_not_supported ();
+#endif
+ break;
case _NNS_INT64:
handle_operator (desc, val, op, int64_t);
break;