[Transform] refactor with common tensor info
authorJaeyun <jy1210.jung@samsung.com>
Wed, 7 Nov 2018 11:39:07 +0000 (20:39 +0900)
committerMyungJoo Ham <myungjoo.ham@gmail.com>
Thu, 8 Nov 2018 07:16:21 +0000 (16:16 +0900)
change code to use common structure for in/out tensor info

Signed-off-by: Jaeyun Jung <jy1210.jung@samsung.com>
gst/tensor_transform/tensor_transform.c
gst/tensor_transform/tensor_transform.h

index 610c1e2..22fe903 100644 (file)
@@ -226,7 +226,9 @@ gst_tensor_transform_init (GstTensorTransform * filter)
   filter->mode = GTT_END;
   filter->option = NULL;
   filter->loaded = FALSE;
-  filter->type = _NNS_END;
+
+  gst_tensor_config_init (&filter->in_config);
+  gst_tensor_config_init (&filter->out_config);
 }
 
 /**
@@ -476,20 +478,20 @@ gst_tensor_transform_dimchg (GstTensorTransform * filter,
     const uint8_t * inptr, uint8_t * outptr)
 {
   /** @todo NYI */
-  uint32_t *fromDim = filter->fromDim;
-  uint32_t *toDim = filter->toDim;
+  uint32_t *fromDim = filter->in_config.info.dimension;
+  uint32_t *toDim = filter->out_config.info.dimension;
+  tensor_type in_tensor_type = filter->in_config.info.type;
   int from = filter->data_dimchg.from;
   int to = filter->data_dimchg.to;
   int i, j, k;
   unsigned int loopLimit = 1;
-  size_t loopBlockSize = tensor_element_size[filter->type];
-  size_t copyblocksize = tensor_element_size[filter->type];
+  size_t loopBlockSize = tensor_element_size[in_tensor_type];
+  size_t copyblocksize = tensor_element_size[in_tensor_type];
   size_t copyblocklimit = 1;
 
   if (from == to) {
     /** Useless memcpy. Do not call this or @todo do "IP" operation */
-    memcpy (outptr, inptr, get_tensor_element_count (filter->fromDim) *
-        tensor_element_size[filter->type]);
+    memcpy (outptr, inptr, gst_tensor_info_get_size (&filter->in_config.info));
     g_printerr
         ("Calling tensor_transform with high memcpy overhead WITHOUT any effects! Check your stream wheter you really need tensor_transform.\n");
     return GST_FLOW_OK;
@@ -586,7 +588,7 @@ gst_tensor_transform_dimchg (GstTensorTransform * filter,
  * This is for cases otype is numeral.
  */
 #define numotype_castloop_per_itype(otype,num) do { \
-    switch (filter->type) { \
+    switch (in_tensor_type) { \
     case _NNS_INT8: castloop(int8_t, otype, num); break; \
     case _NNS_INT16: castloop(int16_t, otype, num); break; \
     case _NNS_INT32: castloop(int32_t, otype, num); break; \
@@ -602,7 +604,7 @@ gst_tensor_transform_dimchg (GstTensorTransform * filter,
   } while (0)
 
 #define numotype_castloop_via_intermediate_for_float_itype(mtype, otype, num) do { \
-    switch (filter->type) { \
+    switch (in_tensor_type) { \
      case _NNS_FLOAT32:\
       castloop_via_intermediate(float, mtype, otype, num); \
       break; \
@@ -624,7 +626,8 @@ static GstFlowReturn
 gst_tensor_transform_typecast (GstTensorTransform * filter,
     const uint8_t * inptr, uint8_t * outptr)
 {
-  uint32_t num = get_tensor_element_count (filter->fromDim);
+  uint32_t num = get_tensor_element_count (filter->in_config.info.dimension);
+  tensor_type in_tensor_type = filter->in_config.info.type;
 
   switch (filter->data_typecast.to) {
     case _NNS_INT8:
@@ -637,7 +640,7 @@ gst_tensor_transform_typecast (GstTensorTransform * filter,
       numotype_castloop_per_itype (int32_t, num);
       break;
     case _NNS_UINT8:
-      if ((filter->type == _NNS_FLOAT32) || (filter->type == _NNS_FLOAT64)) {
+      if ((in_tensor_type == _NNS_FLOAT32) || (in_tensor_type == _NNS_FLOAT64)) {
         numotype_castloop_via_intermediate_for_float_itype (int8_t, uint8_t,
             num);
       } else {
@@ -645,7 +648,7 @@ gst_tensor_transform_typecast (GstTensorTransform * filter,
       }
       break;
     case _NNS_UINT16:
-      if ((filter->type == _NNS_FLOAT32) || (filter->type == _NNS_FLOAT64)) {
+      if ((in_tensor_type == _NNS_FLOAT32) || (in_tensor_type == _NNS_FLOAT64)) {
         numotype_castloop_via_intermediate_for_float_itype (int16_t, uint16_t,
             num);
       } else {
@@ -653,7 +656,7 @@ gst_tensor_transform_typecast (GstTensorTransform * filter,
       }
       break;
     case _NNS_UINT32:
-      if ((filter->type == _NNS_FLOAT32) || (filter->type == _NNS_FLOAT64)) {
+      if ((in_tensor_type == _NNS_FLOAT32) || (in_tensor_type == _NNS_FLOAT64)) {
         numotype_castloop_via_intermediate_for_float_itype (int32_t, uint32_t,
             num);
       } else {
@@ -670,7 +673,7 @@ gst_tensor_transform_typecast (GstTensorTransform * filter,
       numotype_castloop_per_itype (int64_t, num);
       break;
     case _NNS_UINT64:
-      if ((filter->type == _NNS_FLOAT32) || (filter->type == _NNS_FLOAT64)) {
+      if ((in_tensor_type == _NNS_FLOAT32) || (in_tensor_type == _NNS_FLOAT64)) {
         numotype_castloop_via_intermediate_for_float_itype (int64_t, uint64_t,
             num);
       } else {
@@ -790,11 +793,12 @@ static GstFlowReturn
 gst_tensor_transform_arithmetic (GstTensorTransform * filter,
     const uint8_t * inptr, uint8_t * outptr)
 {
-  uint32_t num = get_tensor_element_count (filter->fromDim);
+  uint32_t num = get_tensor_element_count (filter->in_config.info.dimension);
+  tensor_type in_tensor_type = filter->in_config.info.type;
   tensor_transform_arith_mode mode = filter->data_arithmetic.mode;
   tensor_transform_arithmetic_operand *value = filter->data_arithmetic.value;
 
-  switch (filter->type) {
+  switch (in_tensor_type) {
       arithloopcase (_NNS_INT8, int8_t, num, mode, value);
       arithloopcase (_NNS_INT16, int16_t, num, mode, value);
       arithloopcase (_NNS_INT32, int32_t, num, mode, value);
@@ -844,8 +848,9 @@ gst_tensor_transform_transpose (GstTensorTransform * filter,
 {
   int i, from, to;
   gboolean checkdim = FALSE;
-  uint32_t *fromDim = filter->fromDim;
-  size_t type_size = tensor_element_size[filter->type];
+  uint32_t *fromDim = filter->in_config.info.dimension;
+  tensor_type in_tensor_type = filter->in_config.info.type;
+  size_t type_size = tensor_element_size[in_tensor_type];
   size_t indexI, indexJ, SL, SI, SJ, SK;
   for (i = 0; i < NNS_TENSOR_RANK_LIMIT; i++) {
     from = i;
@@ -857,8 +862,7 @@ gst_tensor_transform_transpose (GstTensorTransform * filter,
   }
 
   if (!checkdim) {
-    memcpy (outptr, inptr,
-        get_tensor_element_count (filter->fromDim) * type_size);
+    memcpy (outptr, inptr, gst_tensor_info_get_size (&filter->in_config.info));
     g_printerr
         ("Calling tensor_transform with high memcpy overhead WITHOUT any effects!");
     return GST_FLOW_OK;
@@ -911,7 +915,7 @@ gst_tensor_transform_stand (GstTensorTransform * filter,
 {
   int i;
   size_t Size;
-  uint32_t *fromDim = filter->fromDim;
+  uint32_t *fromDim = filter->in_config.info.dimension;
   double average, stand;
 
   float *in = (float *) inptr;
@@ -1113,7 +1117,7 @@ gst_tensor_transform_convert_dimension (GstTensorTransform * filter,
         out_info->type = filter->data_typecast.to;
       } else {
           /** src = SRCPAD / dest = SINKPAD */
-        out_info->type = filter->type;   /** @todo this may cause problems with Cap-Transform */
+        out_info->type = in_info->type;   /** @todo this may cause problems with Cap-Transform */
       }
       break;
     }
@@ -1181,7 +1185,6 @@ gst_tensor_transform_transform_caps (GstBaseTransform * trans,
   GstCaps *result = NULL;
 
   filter = GST_TENSOR_TRANSFORM_CAST (trans);
-  g_assert (filter->loaded);
 
   silent_debug ("Calling TransformCaps, direction = %d\n", direction);
   silent_debug_caps (caps, "from");
@@ -1190,29 +1193,20 @@ gst_tensor_transform_transform_caps (GstBaseTransform * trans,
   gst_tensor_config_init (&in_config);
   gst_tensor_config_init (&out_config);
 
-  if (direction == GST_PAD_SINK) {
-    if (gst_tensor_transform_read_caps (caps, &in_config)) {
-      gst_tensor_transform_convert_dimension (filter, direction,
-          &in_config.info, &out_config.info);
-    }
-
-    /**
-     * supposed same framerate from input configuration
-     */
-    out_config.rate_n = in_config.rate_n;
-    out_config.rate_d = in_config.rate_d;
+  if (gst_tensor_transform_read_caps (caps, &in_config)) {
+    gst_tensor_transform_convert_dimension (filter, direction,
+        &in_config.info, &out_config.info);
+  }
 
-    result = gst_tensor_caps_from_config (&out_config);
-  } else {
-    if (gst_tensor_transform_read_caps (caps, &out_config)) {
-      gst_tensor_transform_convert_dimension (filter, direction,
-          &out_config.info, &in_config.info);
-    }
+  /**
+   * supposed same framerate from input configuration
+   */
+  out_config.rate_n = in_config.rate_n;
+  out_config.rate_d = in_config.rate_d;
 
-    result = gst_tensor_caps_from_config (&in_config);
-  }
+  result = gst_tensor_caps_from_config (&out_config);
 
-  if (filtercap) {
+  if (filtercap && gst_caps_get_size (filtercap) > 0) {
     GstCaps *intersection;
 
     intersection =
@@ -1242,18 +1236,12 @@ gst_tensor_transform_fixate_caps (GstBaseTransform * trans,
   silent_debug_caps (caps, "caps");
   silent_debug_caps (othercaps, "othercaps");
 
-  result = gst_tensor_transform_transform_caps (trans, direction, caps, NULL);
+  result =
+      gst_tensor_transform_transform_caps (trans, direction, caps, othercaps);
+  gst_caps_unref (othercaps);
 
-  if (othercaps) {
-    GstCaps *intersection;
-
-    intersection =
-        gst_caps_intersect_full (result, othercaps, GST_CAPS_INTERSECT_FIRST);
-
-    gst_caps_unref (othercaps);
-    gst_caps_unref (result);
-    result = intersection;
-  }
+  result = gst_caps_make_writable (result);
+  result = gst_caps_fixate (result);
 
   silent_debug_caps (result, "result");
   return result;
@@ -1267,9 +1255,8 @@ gst_tensor_transform_set_caps (GstBaseTransform * trans,
     GstCaps * incaps, GstCaps * outcaps)
 {
   GstTensorTransform *filter;
-  GstTensorConfig in_config;
-  GstTensorConfig out_config;
-  guint i;
+  GstTensorConfig in_config, out_config;
+  GstTensorConfig config;
 
   filter = GST_TENSOR_TRANSFORM_CAST (trans);
 
@@ -1289,47 +1276,24 @@ gst_tensor_transform_set_caps (GstBaseTransform * trans,
     goto error;
   }
 
-  /** check framerate */
+  /* check framerate */
   if (in_config.rate_n != out_config.rate_n
       || in_config.rate_d != out_config.rate_d) {
     silent_debug ("Framerate is not matched\n");
     goto error;
   }
 
-  /**
-   * Update in/out tensor info (dimension, type)
-   */
-  for (i = 0; i < NNS_TENSOR_RANK_LIMIT; i++) {
-    filter->fromDim[i] = in_config.info.dimension[i];
-    filter->toDim[i] = out_config.info.dimension[i];
+  /* compare type and dimension */
+  if (!gst_tensor_transform_convert_dimension (filter, GST_PAD_SINK,
+          &in_config.info, &config.info) ||
+      !gst_tensor_info_is_equal (&out_config.info, &config.info)) {
+    silent_debug ("Tensor info is not matched with given properties.\n");
+    goto error;
   }
 
-  if (filter->type == _NNS_END)
-    filter->type = in_config.info.type;
-
-  switch (filter->mode) {
-    case GTT_TRANSPOSE:
-    case GTT_ARITHMETIC:
-    case GTT_STAND:
-    case GTT_DIMCHG:
-      if (in_config.info.type != out_config.info.type
-          || filter->type != in_config.info.type) {
-        silent_debug ("Filter Type Not Matched\n");
-        goto error;
-      }
-      break;
-    case GTT_TYPECAST:
-      if (filter->type != in_config.info.type
-          || filter->data_typecast.to != out_config.info.type) {
-        silent_debug ("Filter Type Not Matched\n Input %d/%d | Output %d/%d",
-            filter->type, in_config.info.type, filter->data_typecast.to,
-            out_config.info.type);
-        goto error;
-      }
-      break;
-    default:
-      break;
-  }
+  /* set in/out tensor info */
+  filter->in_config = in_config;
+  filter->out_config = out_config;
 
   return TRUE;
 error:
@@ -1349,28 +1313,11 @@ gst_tensor_transform_transform_size (GstBaseTransform * trans,
 
   filter = GST_TENSOR_TRANSFORM_CAST (trans);
 
-  switch (filter->mode) {
-    case GTT_TRANSPOSE:
-    case GTT_ARITHMETIC:
-    case GTT_STAND:
-    case GTT_DIMCHG:
-      *othersize = size;        /* size of input = size of output if dimchg */
-      break;
-    case GTT_TYPECAST:
-    {
-      size_t srcunitsize = tensor_element_size[filter->type];
-      size_t dstunitsize = tensor_element_size[filter->data_typecast.to];
-      if (size % srcunitsize > 0)
-        return FALSE;
-      *othersize = size / srcunitsize * dstunitsize;
-      break;
-    }
-    default:
-      return FALSE;
-  }
+  /**
+   * supposed output tensor configured, then get size from output tensor info.
+   */
+  *othersize = gst_tensor_info_get_size (&filter->out_config.info);
   return TRUE;
-
-  /** @todo add verificastion procedure */
 }
 
 /**
index c4a9bea..6377512 100644 (file)
@@ -55,7 +55,7 @@ typedef enum
 {
   GTT_DIMCHG = 0,               /* Dimension Change. "dimchg" */
   GTT_TYPECAST = 1,             /* Type change. "typecast" */
-  GTT_ARITHMETIC = 2,           /* Type change. "typecast" */
+  GTT_ARITHMETIC = 2,           /* Arithmetic. "arithmetic" */
   GTT_TRANSPOSE = 3,            /* Transpose. "transpose" */
   GTT_STAND = 4,                /* Standardization. "stand" */
 
@@ -153,9 +153,8 @@ struct _GstTensorTransform
   };
   gboolean loaded; /**< TRUE if mode & option are loaded */
 
-  tensor_dim fromDim; /**< Input dimension */
-  tensor_dim toDim; /**< Output dimension */
-  tensor_type type; /**< tensor_type of input. Most transform share the same type for both input and output. However, this does not hold for typecast. */
+  GstTensorConfig in_config; /**< input tensor info */
+  GstTensorConfig out_config; /**< output tensor info */
 };
 
 /**