[C-Api] check nnfw with file extension
authorJaeyun <jy1210.jung@samsung.com>
Wed, 26 Jun 2019 06:53:03 +0000 (15:53 +0900)
committerjaeyun-jung <39614140+jaeyun-jung@users.noreply.github.com>
Tue, 2 Jul 2019 02:33:51 +0000 (11:33 +0900)
Check given model file has valid file extension.
If the param nnfw is unknown and file ext is valid, determine fw with the ext.

Signed-off-by: Jaeyun Jung <jy1210.jung@samsung.com>
api/capi/src/nnstreamer-capi-single.c
tests/tizen_capi/unittest_tizen_capi.cpp

index 8047898..e21e574 100644 (file)
@@ -95,6 +95,7 @@ ml_single_open (ml_single_h * single, const char *model_path,
   GstCaps *caps;
   int status = ML_ERROR_NONE;
   gchar *pipeline_desc = NULL;
+  gchar *path_down;
 
   /* Validate the params */
   if (!single) {
@@ -105,12 +106,6 @@ ml_single_open (ml_single_h * single, const char *model_path,
   /* init null */
   *single = NULL;
 
-  if (!g_file_test (model_path, G_FILE_TEST_IS_REGULAR)) {
-    ml_loge ("The given param, model path [%s] is invalid.",
-        GST_STR_NULL (model_path));
-    return ML_ERROR_INVALID_PARAMETER;
-  }
-
   if (input_info &&
       ml_util_validate_tensors_info (input_info) != ML_ERROR_NONE) {
     ml_loge ("The given param, input tensor info is invalid.");
@@ -123,14 +118,65 @@ ml_single_open (ml_single_h * single, const char *model_path,
     return ML_ERROR_INVALID_PARAMETER;
   }
 
+  /* 1. Determine nnfw */
+  /* Check file extention. */
+  path_down = g_ascii_strdown (model_path, -1);
+
+  switch (nnfw) {
+    case ML_NNFW_UNKNOWN:
+      if (g_str_has_suffix (path_down, ".tflite")) {
+        ml_logi ("The given model [%s] is supposed a tensorflow-lite model.", model_path);
+        nnfw = ML_NNFW_TENSORFLOW_LITE;
+      } else if (g_str_has_suffix (path_down, ".pb")) {
+        ml_logi ("The given model [%s] is supposed a tensorflow model.", model_path);
+        nnfw = ML_NNFW_TENSORFLOW;
+      } else {
+        ml_loge ("The given model [%s] has unknown extension.", model_path);
+        status = ML_ERROR_INVALID_PARAMETER;
+      }
+      break;
+    case ML_NNFW_CUSTOM_FILTER:
+      if (!g_str_has_suffix (path_down, ".so")) {
+        ml_loge ("The given model [%s] has invalid extension.", model_path);
+        status = ML_ERROR_INVALID_PARAMETER;
+      }
+      break;
+    case ML_NNFW_TENSORFLOW_LITE:
+      if (!g_str_has_suffix (path_down, ".tflite")) {
+        ml_loge ("The given model [%s] has invalid extension.", model_path);
+        status = ML_ERROR_INVALID_PARAMETER;
+      }
+      break;
+    case ML_NNFW_TENSORFLOW:
+      if (!g_str_has_suffix (path_down, ".pb")) {
+        ml_loge ("The given model [%s] has invalid extension.", model_path);
+        status = ML_ERROR_INVALID_PARAMETER;
+      }
+      break;
+    default:
+      break;
+  }
+
+  g_free (path_down);
+  if (status != ML_ERROR_NONE)
+    return status;
+
+  if (!g_file_test (model_path, G_FILE_TEST_IS_REGULAR)) {
+    ml_loge ("The given param, model path [%s] is invalid.",
+        GST_STR_NULL (model_path));
+    return ML_ERROR_INVALID_PARAMETER;
+  }
+
+  /* 2. Determine hw */
+  /** @todo Now the param hw is ignored. (Supposed CPU only) Support others later. */
   status = ml_util_check_nnfw (nnfw, hw);
   if (status < 0) {
     ml_loge ("The given nnfw is not available.");
     return status;
   }
 
-  /* 1. Determine nnfw */
-  /** @todo Check nnfw with file extention. */
+  /* 3. Construct a pipeline */
+  /* Set the pipeline desc with nnfw. */
   switch (nnfw) {
     case ML_NNFW_CUSTOM_FILTER:
       pipeline_desc =
@@ -139,22 +185,13 @@ ml_single_open (ml_single_h * single, const char *model_path,
           model_path);
       break;
     case ML_NNFW_TENSORFLOW_LITE:
-      if (!g_str_has_suffix (model_path, ".tflite")) {
-        ml_loge ("The given model file [%s] has invalid extension.", model_path);
-        return ML_ERROR_INVALID_PARAMETER;
-      }
-
+      /* We can get the tensor meta from tf-lite model. */
       pipeline_desc =
           g_strdup_printf
           ("appsrc name=srcx ! tensor_filter name=filterx framework=tensorflow-lite model=%s ! appsink name=sinkx sync=false",
           model_path);
       break;
     case ML_NNFW_TENSORFLOW:
-      if (!g_str_has_suffix (model_path, ".pb")) {
-        ml_loge ("The given model file [%s] has invalid extension.", model_path);
-        return ML_ERROR_INVALID_PARAMETER;
-      }
-
       if (input_info && output_info) {
         GstTensorsInfo in_info, out_info;
         gchar *str_dim, *str_type, *str_name;
@@ -203,10 +240,6 @@ ml_single_open (ml_single_h * single, const char *model_path,
       return ML_ERROR_NOT_SUPPORTED;
   }
 
-  /* 2. Determine hw */
-  /** @todo Now the param hw is ignored. (Supposed CPU only) Support others later. */
-
-  /* 3. Construct a pipeline */
   status = ml_pipeline_construct (pipeline_desc, &pipe);
   g_free (pipeline_desc);
   if (status != ML_ERROR_NONE) {
index 4004212..280bf57 100644 (file)
@@ -1497,15 +1497,23 @@ TEST (nnstreamer_capi_singleshot, failure_01)
   out_info.info[0].dimension[2] = 1;
   out_info.info[0].dimension[3] = 1;
 
-  /* unknown fw type */
+  /* invalid file extension */
   status = ml_single_open (&single, test_model, &in_info, &out_info,
-      ML_NNFW_UNKNOWN, ML_NNFW_HW_DO_NOT_CARE);
-  EXPECT_EQ (status, ML_ERROR_NOT_SUPPORTED);
+      ML_NNFW_TENSORFLOW, ML_NNFW_HW_DO_NOT_CARE);
+  EXPECT_EQ (status, ML_ERROR_INVALID_PARAMETER);
 
   /* invalid handle */
   status = ml_single_close (single);
   EXPECT_EQ (status, ML_ERROR_INVALID_PARAMETER);
 
+  /* Successfully opened unknown fw type (tf-lite) */
+  status = ml_single_open (&single, test_model, &in_info, &out_info,
+      ML_NNFW_UNKNOWN, ML_NNFW_HW_DO_NOT_CARE);
+  EXPECT_EQ (status, ML_ERROR_NONE);
+
+  status = ml_single_close (single);
+  EXPECT_EQ (status, ML_ERROR_NONE);
+
   g_free (test_model);
 }
 #endif /* ENABLE_TENSORFLOW_LITE */