tensor-transform: float16 support.
authorMyungJoo Ham <myungjoo.ham@samsung.com>
Fri, 17 Jun 2022 04:26:33 +0000 (13:26 +0900)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Sun, 3 Jul 2022 04:42:13 +0000 (13:42 +0900)
ORC/SIMD support not included.
Float16 w/ heavy tensor-transform is not recommended.

- Generate error if float16 is used in unsupported build
- Generate warning if float16 op is used too heavily (it's slow!)
- Future todo: support SIMD for float16 ops

Signed-off-by: MyungJoo Ham <myungjoo.ham@samsung.com>
gst/nnstreamer/tensor_data.c
gst/nnstreamer/tensor_transform/tensor_transform.c

index 980cc8c..05a5603 100644 (file)
@@ -208,7 +208,8 @@ gst_tensor_data_typecast (tensor_data_s * td, tensor_type type)
 
   /* do nothing when transform to same type */
   if (td->type != type) {
-    is_float = (td->type == _NNS_FLOAT32 || td->type == _NNS_FLOAT64 || td->type == _NNS_FLOAT16);
+    is_float = (td->type == _NNS_FLOAT32 || td->type == _NNS_FLOAT64
+        || td->type == _NNS_FLOAT16);
 
     switch (type) {
       case _NNS_INT32:
@@ -251,7 +252,8 @@ gst_tensor_data_typecast (tensor_data_s * td, tensor_type type)
 #ifdef FLOAT16_SUPPORT
         td_typecast (td, float16);
 #else
-        nns_loge ("NNStreamer requires -DFLOAT16_SUPPORT as a build option to enable float16 type. This binary does not have float16 feature enabled; thus, float16 type is not supported in this instance.\n");
+        nns_loge
+            ("NNStreamer requires -DFLOAT16_SUPPORT as a build option to enable float16 type. This binary does not have float16 feature enabled; thus, float16 type is not supported in this instance.\n");
         return FALSE;
 #endif
         break;
index 41b649e..d83c61f 100644 (file)
@@ -68,16 +68,16 @@ GST_DEBUG_CATEGORY_STATIC (gst_tensor_transform_debug);
 #define GST_CAT_DEFAULT gst_tensor_transform_debug
 #define CAPS_STRING GST_TENSOR_CAP_DEFAULT ";" GST_TENSORS_CAP_MAKE ("{ static, flexible }")
 #define REGEX_DIMCHG_OPTION "^([0-3]):([0-3])$"
-#define REGEX_TYPECAST_OPTION "(^[u]?int(8|16|32|64)$|^float(32|64)$)"
+#define REGEX_TYPECAST_OPTION "(^[u]?int(8|16|32|64)$|^float(16|32|64)$)"
 #define REGEX_TRANSPOSE_OPTION "^(?:([0-2]):(?!.*\\1)){3}3$"
-#define REGEX_STAND_OPTION "^(default|dc-average)(:([u]?int(8|16|32|64)|float(32|64)))?(,per-channel:(true|false))?$"
+#define REGEX_STAND_OPTION "^(default|dc-average)(:([u]?int(8|16|32|64)|float(16|32|64)))?(,per-channel:(true|false))?$"
 #define REGEX_CLAMP_OPTION "^((([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?))):"\
     "((([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)))$"
-#define REGEX_ARITH_OPTION "^(typecast:([u]?int(8|16|32|64)|float(32|64)),)?"\
+#define REGEX_ARITH_OPTION "^(typecast:([u]?int(8|16|32|64)|float(16|32|64)),)?"\
     "(per-channel:(false|true@[0-9]+),)?"\
     "(((add|mul|div)(:([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?))+(@[0-9]+)?)(,|))+$"
 
-#define REGEX_ARITH_OPTION_TYPECAST "(typecast:([u]?int(8|16|32|64)|float(32|64)))"
+#define REGEX_ARITH_OPTION_TYPECAST "(typecast:([u]?int(8|16|32|64)|float(16|32|64)))"
 
 /**
  * @brief tensor_transform properties
@@ -318,6 +318,129 @@ gst_tensor_transform_get_stand_mode (const gchar * str)
   return (index < 0) ? STAND_END : index;
 }
 
+#ifndef FLOAT16_SUPPORT
+/**
+ * @brief Generate error if float16 is required.
+ */
+static void
+float16_not_supported (void)
+{
+  ml_loge
+      ("Tensor_tranform does not support float16 operators. Apply -Denable-float16=true for meson build option if your architecture support float16. Note that tensor-transform's float16 is adhoc and does NOT perform good (slow!).\n");
+  g_assert (0);
+}
+#endif
+
+#ifdef FLOAT16_SUPPORT
+/** @todo Remove this after applying SIMD or ORC */
+static void
+refrain_from_heavy_op_on_float16 (gulong n)
+{
+  static int warned = 0;
+  /* 1 million */
+  if (n > 1000000) {
+    if (warned)
+      return;
+    ml_logw
+        ("Tensor_transform implementation for float16 does not support SIMD. Heavy tensor-transform operations of float16 is not recommended. Try to apply heavy ops with other types (e.g., float32) and convert it to float16 at the time when it's really needed.\n");
+    warned = 1;
+  }
+}
+
+/** @todo Make this use SIMD or ORC */
+#define _conv_to_f16(intype, o, i, n) \
+  do { \
+    float16 *op = (gpointer) (o); \
+    intype *ip = (gpointer) (i); \
+    gulong idx; \
+    refrain_from_heavy_op_on_float16 (n); \
+    for (idx = 0; idx < n; idx++) \
+      *(op + idx) = (float16) *(ip + idx); \
+  } while (0)
+
+/** @todo Make this use SIMD or ORC */
+#define _conv_from_f16_action(n, op, ip, otypename) \
+  do { \
+    gulong idx; \
+    for (idx = 0; idx < n; idx++) \
+      *(op + idx) = (otypename) *(ip + idx); \
+  } while (0)
+
+/** @todo Make this use SIMD or ORC */
+#define _conv_from_f16(otype, o, i, n) \
+  do { \
+    float16 *ip = (gpointer) (i); \
+    refrain_from_heavy_op_on_float16 (n); \
+    switch (otype) { \
+      case _NNS_INT32: { \
+        int32_t *op = (gpointer) (o); \
+        _conv_from_f16_action (n, op, ip, int32_t); \
+        break; } \
+      case _NNS_UINT32: {  \
+        uint32_t *op = (gpointer) (o); \
+        _conv_from_f16_action (n, op, ip, uint32_t); \
+        break; } \
+      case _NNS_INT16: {  \
+        int16_t *op = (gpointer) (o); \
+        _conv_from_f16_action (n, op, ip, int16_t); \
+        break; } \
+      case _NNS_UINT16: {  \
+        uint16_t *op = (gpointer) (o); \
+        _conv_from_f16_action (n, op, ip, uint16_t); \
+        break; } \
+      case _NNS_INT8: {  \
+        int8_t *op = (gpointer) (o); \
+        _conv_from_f16_action (n, op, ip, int8_t); \
+        break; } \
+      case _NNS_UINT8: {  \
+        uint8_t *op = (gpointer) (o); \
+        _conv_from_f16_action (n, op, ip, uint8_t); \
+        break; } \
+      case _NNS_FLOAT64: {  \
+        double *op = (gpointer) (o); \
+        _conv_from_f16_action (n, op, ip, double); \
+        break; } \
+      case _NNS_FLOAT32: {  \
+        float *op = (gpointer) (o); \
+        _conv_from_f16_action (n, op, ip, float); \
+        break; } \
+      case _NNS_FLOAT16: {  \
+        float16 *op = (gpointer) (o); \
+        _conv_from_f16_action (n, op, ip, float16); \
+        break; } \
+      default: GST_ERROR_OBJECT (filter, "Unsupported type %d", (otype)); g_assert (0); \
+    } \
+  } while (0)
+
+/** @todo Make this use SIMD or ORC */
+#define _op_float16(i, n, v, op) \
+  do { \
+    gulong idx; \
+    float16 *data_in = (float16 *) (i); \
+    refrain_from_heavy_op_on_float16 (n); \
+    switch (op) { \
+      case GTT_OP_ADD: \
+        for (idx = 0; idx < n; idx++) \
+          data_in[idx] = data_in[idx] + (v); \
+        break; \
+      case GTT_OP_MUL: \
+        for (idx = 0; idx < n; idx++) \
+          data_in[idx] = data_in[idx] * (v); \
+        break; \
+      case GTT_OP_DIV: \
+        for (idx = 0; idx < n; idx++) \
+          data_in[idx] = data_in[idx] / (v); \
+        break; \
+      default: GST_ERROR_OBJECT (filter, "Unknown operator for float16: %d", op); break; \
+    } \
+  } while (0)
+
+#else /* ! FLOAT16_SUPPORT */
+#define _conv_to_f16(intype, o, i, n) do { float16_not_supported (); } while (0)
+#define _conv_from_f16(otype, o, i, n) do { float16_not_supported (); } while (0)
+#define _op_float16(i, n, v, op) do { float16_not_supported (); } while (0)
+#endif /* FLOAT16_SUPPORT */
+
 #ifdef HAVE_ORC
 /* define macros for orc */
 /** @todo support 64bit integer and remove below line */
@@ -329,7 +452,7 @@ gst_tensor_transform_get_stand_mode (const gchar * str)
 #define orc_func_mul(intype) nns_orc_mul_c_ ## intype
 #define orc_func_div(intype) nns_orc_div_c_ ## intype
 
-#define orc_typecast_to(i,o,n,intype,otype) do { \
+#define orc_typecast_to(i,o,n,intype,otype,intypename) do { \
     switch (otype) { \
       case _NNS_INT32: orc_func_conv (intype, s32) ((gpointer) o, (gpointer) i, n); break; \
       case _NNS_UINT32: orc_func_conv (intype, u32) ((gpointer) o, (gpointer) i, n); break; \
@@ -339,25 +462,27 @@ gst_tensor_transform_get_stand_mode (const gchar * str)
       case _NNS_UINT8: orc_func_conv (intype, u8) ((gpointer) o, (gpointer) i, n); break; \
       case _NNS_FLOAT64: orc_func_conv (intype, f64) ((gpointer) o, (gpointer) i, n); break; \
       case _NNS_FLOAT32: orc_func_conv (intype, f32) ((gpointer) o, (gpointer) i, n); break; \
+      case _NNS_FLOAT16: _conv_to_f16 (intypename, o, i, n); break; \
       default: GST_ERROR_OBJECT (filter, "Unsupported output type %d", otype); g_assert (0); break; \
     } \
   } while (0)
 
 #define orc_typecast(i,o,n,itype,otype) do { \
     switch (itype) { \
-      case _NNS_INT32: orc_typecast_to (i, o, n, s32, otype); break; \
-      case _NNS_UINT32: orc_typecast_to (i, o, n, u32, otype); break; \
-      case _NNS_INT16: orc_typecast_to (i, o, n, s16, otype); break; \
-      case _NNS_UINT16: orc_typecast_to (i, o, n, u16, otype); break; \
-      case _NNS_INT8: orc_typecast_to (i, o, n, s8, otype); break; \
-      case _NNS_UINT8: orc_typecast_to (i, o, n, u8, otype); break; \
-      case _NNS_FLOAT64: orc_typecast_to (i, o, n, f64, otype); break; \
-      case _NNS_FLOAT32: orc_typecast_to (i, o, n, f32, otype); break; \
+      case _NNS_INT32: orc_typecast_to (i, o, n, s32, otype, int32_t); break; \
+      case _NNS_UINT32: orc_typecast_to (i, o, n, u32, otype, uint32_t); break; \
+      case _NNS_INT16: orc_typecast_to (i, o, n, s16, otype, int16_t); break; \
+      case _NNS_UINT16: orc_typecast_to (i, o, n, u16, otype, uint16_t); break; \
+      case _NNS_INT8: orc_typecast_to (i, o, n, s8, otype, int8_t); break; \
+      case _NNS_UINT8: orc_typecast_to (i, o, n, u8, otype, uint8_t); break; \
+      case _NNS_FLOAT64: orc_typecast_to (i, o, n, f64, otype, double); break; \
+      case _NNS_FLOAT32: orc_typecast_to (i, o, n, f32, otype, float); break; \
+      case _NNS_FLOAT16: _conv_from_f16 (otype, o, i, n); break; \
       default: GST_ERROR_OBJECT (filter, "Unsupported input type %d", itype); g_assert (0); break; \
     } \
   } while (0)
 
-#define orc_operator_func(i,n,v,opfunc) do { \
+#define orc_operator_func(i,n,v,opfunc,op) do { \
     switch ((v)->type) { \
       case _NNS_INT32: opfunc (s32) ((gpointer) i, (v)->data._int32_t, n); break; \
       case _NNS_UINT32: opfunc (u32) ((gpointer) i, (v)->data._uint32_t, n); break; \
@@ -367,6 +492,7 @@ gst_tensor_transform_get_stand_mode (const gchar * str)
       case _NNS_UINT8: opfunc (u8) ((gpointer) i, (v)->data._uint8_t, n); break; \
       case _NNS_FLOAT64: opfunc (f64) ((gpointer) i, (v)->data._double, n); break; \
       case _NNS_FLOAT32: opfunc (f32) ((gpointer) i, (v)->data._float, n); break; \
+      case _NNS_FLOAT16: _op_float16 (i, n, (v)->data._float16, op); break; \
       default: GST_ERROR_OBJECT (filter, "Unsupported type %d", (v)->type); g_assert (0); break; \
     } \
   } while (0)
@@ -381,8 +507,8 @@ gst_tensor_transform_get_stand_mode (const gchar * str)
 
 #define orc_operator(i,n,v,op) do { \
     switch (op) { \
-      case GTT_OP_ADD: orc_operator_func (i, n, v, orc_func_add); break; \
-      case GTT_OP_MUL: orc_operator_func (i, n, v, orc_func_mul); break; \
+      case GTT_OP_ADD: orc_operator_func (i, n, v, orc_func_add, op); break; \
+      case GTT_OP_MUL: orc_operator_func (i, n, v, orc_func_mul, op); break; \
       case GTT_OP_DIV: \
         switch ((v)->type) { \
           case _NNS_INT32: orc_operator_div_loop (i, n, (v)->data._int32_t, int32_t); break; \
@@ -393,6 +519,7 @@ gst_tensor_transform_get_stand_mode (const gchar * str)
           case _NNS_UINT8: orc_operator_div_loop (i, n, (v)->data._uint8_t, uint8_t); break; \
           case _NNS_FLOAT64: orc_func_div (f64) ((gpointer) i, (v)->data._double, n); break; \
           case _NNS_FLOAT32: orc_func_div (f32) ((gpointer) i, (v)->data._float, n); break; \
+          case _NNS_FLOAT16: _op_float16 (i, n, (v)->data._float16, op); break; \
           default: GST_ERROR_OBJECT (filter, "Unsupported type %d", (v)->type); g_assert (0); break; \
         } \
         break; \
@@ -467,6 +594,13 @@ gst_tensor_transform_do_operator (GstTensorTransform * filter,
     case _NNS_FLOAT32:
       handle_operator (desc, val, op, float);
       break;
+    case _NNS_FLOAT16:
+#ifdef FLOAT16_SUPPORT
+      handle_operator (desc, val, op, float16);
+#else
+      float16_not_supported ();
+#endif
+      break;
     case _NNS_INT64:
       handle_operator (desc, val, op, int64_t);
       break;