[Filter] add the model path parser
authorHyoung Joo Ahn <hello.ahn@samsung.com>
Mon, 27 May 2019 07:37:15 +0000 (16:37 +0900)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Thu, 6 Jun 2019 04:59:41 +0000 (00:59 -0400)
Since caffe2 require 2 model files(init, prediction) the model path option should handle the complicated input string

Signed-off-by: Hyoung Joo Ahn <hello.ahn@samsung.com>
gst/nnstreamer/nnstreamer_plugin_api_filter.h
gst/nnstreamer/tensor_filter/tensor_filter.c

index 78f4689..ab8969f 100644 (file)
@@ -37,6 +37,7 @@ typedef struct _GstTensorFilterProperties
   const char *fwname; /**< The name of NN Framework */
   int fw_opened; /**< TRUE IF open() is called or tried. Use int instead of gboolean because this is refered by custom plugins. */
   const char *model_file; /**< Filepath to the model file (as an argument for NNFW). char instead of gchar for non-glib custom plugins */
+  const char *model_file_sub; /**< Filepath to the init model file (as an argument for NNFW). Some frameworks need this file to initialize the graph(caffe, caffe2) */
 
   int input_configured; /**< TRUE if input tensor is configured. Use int instead of gboolean because this is refered by custom plugins. */
   GstTensorsInfo input_meta; /**< configured input tensor info */
index dc67cc5..a1b12d5 100644 (file)
@@ -308,8 +308,8 @@ gst_tensor_filter_class_init (GstTensorFilterClass * klass)
           G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS));
   g_object_class_install_property (gobject_class, PROP_MODEL,
       g_param_spec_string ("model", "Model filepath",
-          "File path to the model file", "",
-          G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS));
+          "File path to the model file. Separated with \
+          ',' in case of multiple model files(like caffe2)", "", G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS));
   g_object_class_install_property (gobject_class, PROP_INPUT,
       g_param_spec_string ("input", "Input dimension",
           "Input tensor dimension from inner array, up to 4 dimensions ?", "",
@@ -435,6 +435,84 @@ gst_tensor_filter_finalize (GObject * object)
 }
 
 /**
+ * @brief Parse the string of model
+ * @param info tensors info structure
+ * @param model_file the prediction model path
+ * @param model_file_sub the initialize model path
+ * @return number of parsed model path
+ * @todo Create a struct list to save multiple model files with key, value pair
+ */
+static guint
+gst_tensors_parse_modelpaths_string (GstTensorFilterProperties * prop,
+    const gchar * model_files)
+{
+  gchar **models;
+  gchar **model_0;
+  gchar **model_1;
+  guint num_models = 0;
+  guint num_model_0 = 0;
+  guint num_model_1 = 0;
+
+  g_return_val_if_fail (prop != NULL, 0);
+
+  if (model_files) {
+    models = g_strsplit_set (model_files, ",", -1);
+    num_models = g_strv_length (models);
+
+    if (num_models == 1) {
+      prop->model_file = g_strdup (models[0]);
+    } else if (num_models == 2) {
+      model_0 = g_strsplit_set (models[0], "=", -1);
+      model_1 = g_strsplit_set (models[1], "=", -1);
+
+      num_model_0 = g_strv_length (model_0);
+      num_model_1 = g_strv_length (model_1);
+
+      if (num_model_0 == 1 && num_model_1 == 1) {
+        prop->model_file_sub = g_strdup (model_0[0]);
+        prop->model_file = g_strdup (model_1[0]);
+      } else if (g_ascii_strncasecmp (model_0[0], "init", 4) == 0 ||
+          g_ascii_strncasecmp (model_0[0], "Init", 4) == 0) {
+        prop->model_file_sub = g_strdup (model_0[1]);
+
+        if (num_model_1 == 2)
+          prop->model_file = g_strdup (model_1[1]);
+        else
+          prop->model_file = g_strdup (model_1[0]);
+      } else if (g_ascii_strncasecmp (model_0[0], "pred", 4) == 0 ||
+          g_ascii_strncasecmp (model_0[0], "Pred", 4) == 0) {
+        prop->model_file = g_strdup (model_0[1]);
+
+        if (num_model_1 == 2)
+          prop->model_file_sub = g_strdup (model_1[1]);
+        else
+          prop->model_file_sub = g_strdup (model_1[0]);
+      } else if (g_ascii_strncasecmp (model_1[0], "init", 4) == 0 ||
+          g_ascii_strncasecmp (model_1[0], "Init", 4) == 0) {
+        prop->model_file_sub = g_strdup (model_1[1]);
+
+        if (num_model_0 == 2)
+          prop->model_file = g_strdup (model_0[1]);
+        else
+          prop->model_file = g_strdup (model_0[0]);
+      } else if (g_ascii_strncasecmp (model_1[0], "pred", 4) == 0 ||
+          g_ascii_strncasecmp (model_1[0], "Pred", 4) == 0) {
+        prop->model_file = g_strdup (model_1[1]);
+
+        if (num_model_0 == 2)
+          prop->model_file_sub = g_strdup (model_0[1]);
+        else
+          prop->model_file_sub = g_strdup (model_0[0]);
+      }
+      g_strfreev (model_0);
+      g_strfreev (model_1);
+    }
+    g_strfreev (models);
+  }
+  return num_models;
+}
+
+/**
  * @brief Calculate output buffer size.
  * @param self "this" pointer
  * @param index index of output tensors (if index < 0, the size of all output tensors will be returned.)
@@ -512,7 +590,8 @@ gst_tensor_filter_set_property (GObject * object, guint prop_id,
     }
     case PROP_MODEL:
     {
-      const gchar *model_file = g_value_get_string (value);
+      const gchar *model_files = g_value_get_string (value);
+      guint model_num;
 
       if (prop->model_file) {
         gst_tensor_filter_close_fw (self);
@@ -520,14 +599,37 @@ gst_tensor_filter_set_property (GObject * object, guint prop_id,
         prop->model_file = NULL;
       }
 
-      /* Once configures, it cannot be changed in runtime */
-      g_assert (model_file);
+      if (prop->model_file_sub) {
+        gst_tensor_filter_close_fw (self);
+        g_free_const (prop->model_file_sub);
+        prop->model_file_sub = NULL;
+      }
 
-      silent_debug ("Model = %s\n", model_file);
-      if (!g_file_test (model_file, G_FILE_TEST_IS_REGULAR)) {
-        GST_ERROR_OBJECT (self, "Cannot find the model file: %s\n", model_file);
+      /* Once configures, it cannot be changed in runtime */
+      /** @todo by using `gst_element_get_state()`, reject configurations in RUNNING or other states */
+      g_assert (model_files);
+      model_num = gst_tensors_parse_modelpaths_string (prop, model_files);
+      if (model_num == 1) {
+        silent_debug ("Model = %s\n", prop->model_file);
+        if (!g_file_test (prop->model_file, G_FILE_TEST_IS_REGULAR))
+          GST_ERROR_OBJECT (self, "Cannot find the model file: %s\n",
+              prop->model_file);
+      } else if (model_num == 2) {
+        silent_debug ("Init Model = %s\n", prop->model_file_sub);
+        silent_debug ("Pred Model = %s\n", prop->model_file);
+        if (!g_file_test (prop->model_file_sub, G_FILE_TEST_IS_REGULAR))
+          GST_ERROR_OBJECT (self, "Cannot find the init model file: %s\n",
+              prop->model_file_sub);
+        if (!g_file_test (prop->model_file, G_FILE_TEST_IS_REGULAR))
+          GST_ERROR_OBJECT (self, "Cannot find the pred model file: %s\n",
+              prop->model_file);
+      } else if (model_num > 2) {
+        /** @todo if the new NN framework requires more than 2 model files, this area will be implemented */
+        GST_ERROR_OBJECT (self,
+            "There is no NN framework that requires model files more than 2. Current Input model files are :%d\n",
+            model_num);
       } else {
-        prop->model_file = g_strdup (model_file);
+        GST_ERROR_OBJECT (self, "Set model file path first\n");
       }
       break;
     }
@@ -1021,7 +1123,7 @@ gst_tensor_filter_compare_tensors (GstTensorsInfo * info1,
 
     line =
         g_strdup_printf ("%2d : %s | %s %s\n", i, left, right,
-        !g_strcmp0 (left, right) ? "" : "FAILED");
+        !strcmp (left, right) ? "" : "FAILED");
     if (left && left[0] != '\0')
       g_free (left);
     if (right && right[0] != '\0')