Consider various NN model formats
author    Inki Dae <inki.dae@samsung.com>
Tue, 18 Feb 2020 04:12:36 +0000 (13:12 +0900)
committer Inki Dae <inki.dae@samsung.com>
Tue, 18 Feb 2020 04:12:36 +0000 (13:12 +0900)
As of now, only the TFLite model format is supported. Other
formats will be supported later.

Change-Id: Ie255875032bfc6051da7a14913564bb9548c69b7
Signed-off-by: Inki Dae <inki.dae@samsung.com>
src/inference_engine_armnn.cpp
src/inference_engine_armnn_private.h

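For reference, the backend's Load() path now needs a model format hint together with the model path list. Below is a minimal caller-side sketch of the updated signature; the direct use of an InferenceARMNN instance, the example file path and the surrounding error handling are illustrative assumptions rather than part of this change.

    #include <string>
    #include <vector>

    #include "inference_engine_armnn_private.h"

    // Hypothetical caller; INFERENCE_MODEL_TFLITE and the
    // INFERENCE_ENGINE_ERROR_* codes come from the inference engine headers.
    static int LoadTfLiteModel(InferenceARMNN &engine)
    {
        std::vector<std::string> model_paths = { "/path/to/model.tflite" };

        // CreateNetwork() validates the format; Caffe, TF and ONNX currently
        // fail with INFERENCE_ENGINE_ERROR_NOT_SUPPORTED_FORMAT.
        int ret = engine.Load(model_paths, INFERENCE_MODEL_TFLITE);
        if (ret != INFERENCE_ENGINE_ERROR_NONE)
            return ret;

        return INFERENCE_ENGINE_ERROR_NONE;
    }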
index 021346da99077cd5d668bb9555802762da9ff234..5fee20d3c16354645c88be3b7dcc14109c35a1bf 100644 (file)
@@ -204,15 +204,45 @@ int InferenceARMNN::CreateTfLiteNetwork(std::string model_path)
     return INFERENCE_ENGINE_ERROR_NONE;
 }
 
-int InferenceARMNN::CreateNetwork(std::string model_path)
+int InferenceARMNN::CreateNetwork(std::vector<std::string> model_paths, inference_model_format_e model_format)
 {
     LOGI("ENTER");
 
-    // TODO. check extension name of model file and call a proper ARMNN parser.
+    // Make sure to check if a given model format is supported or not.
+    if (model_format != INFERENCE_MODEL_CAFFE &&
+        model_format != INFERENCE_MODEL_TF &&
+        model_format != INFERENCE_MODEL_TFLITE &&
+        model_format != INFERENCE_MODEL_ONNX) {
+            LOGE("Invalid model format.");
+            return INFERENCE_ENGINE_ERROR_NOT_SUPPORTED_FORMAT;
+    }
+
+    int ret = INFERENCE_ENGINE_ERROR_NONE;
+
+    switch ((int)model_format) {
+    case INFERENCE_MODEL_CAFFE:
+    case INFERENCE_MODEL_TF:
+    case INFERENCE_MODEL_ONNX:
+        ret = INFERENCE_ENGINE_ERROR_NOT_SUPPORTED_FORMAT;
+        // TODO. Call a proper parser.
+        break;
+    case INFERENCE_MODEL_TFLITE:
+        std::string model_path = model_paths[0];
+        if (access(model_path.c_str(), F_OK)) {
+            LOGE("modelFilePath in [%s] ", model_path.c_str());
+            ret = INFERENCE_ENGINE_ERROR_INVALID_PATH;
+            break;
+        }
+
+        LOGI("It will try to load %s model file", model_path.c_str());
+        return CreateTfLiteNetwork(model_path);
+    }
+
+    LOGE("Model format not supported.");
 
     LOGI("LEAVE");
 
-    return CreateTfLiteNetwork(model_path);
+    return ret;
 }
 
 int InferenceARMNN::Load(std::vector<std::string> model_paths, inference_model_format_e model_format)
@@ -221,15 +251,7 @@ int InferenceARMNN::Load(std::vector<std::string> model_paths, inference_model_f
 
     int ret = INFERENCE_ENGINE_ERROR_NONE;
 
-    std::string model_path = model_paths.front();
-    if (access(model_path.c_str(), F_OK)) {
-        LOGE("modelFilePath in [%s] ", model_path.c_str());
-        return INFERENCE_ENGINE_ERROR_INVALID_PATH;
-    }
-
-    LOGI("Model File = %s", model_path.c_str());
-
-    ret = CreateNetwork(model_path);
+    ret = CreateNetwork(model_paths, model_format);
     if (ret != INFERENCE_ENGINE_ERROR_NONE)
         return ret;
 
index 2fb13d812c620ae6ae2f4e89d7afb6bfc44f52a1..b251cf9adaf3ac52da3058c75f599deec8514b18 100644 (file)
@@ -72,7 +72,7 @@ public:
 
 private:
     int CreateTfLiteNetwork(std::string model_path);
-    int CreateNetwork(std::string model_path);
+    int CreateNetwork(std::vector<std::string> model_paths, inference_model_format_e model_format);
     void *AllocateTensorBuffer(armnn::DataType type, int tensor_size);
     inference_tensor_data_type_e ConvertDataType(armnn::DataType type);
     void ReleaseTensorBuffer(armnn::DataType type, void *tensor_buffer);