return INFERENCE_ENGINE_ERROR_NONE;
}
-int InferenceARMNN::CreateNetwork(std::string model_path)
+int InferenceARMNN::CreateNetwork(std::vector<std::string> model_paths, inference_model_format_e model_format)
{
LOGI("ENTER");
- // TODO. check extension name of model file and call a proper ARMNN parser.
+ // Check whether a given model format is supported.
+ if (model_format != INFERENCE_MODEL_CAFFE &&
+ model_format != INFERENCE_MODEL_TF &&
+ model_format != INFERENCE_MODEL_TFLITE &&
+ model_format != INFERENCE_MODEL_ONNX) {
+ LOGE("Invalid model format.");
+ return INFERENCE_ENGINE_ERROR_NOT_SUPPORTED_FORMAT;
+ }
+
+ int ret = INFERENCE_ENGINE_ERROR_NONE;
+
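+ // Dispatch to a format-specific network parser. Only TFLite is handled for now.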
+ switch ((int)model_format) {
+ case INFERENCE_MODEL_CAFFE:
+ case INFERENCE_MODEL_TF:
+ case INFERENCE_MODEL_ONNX:
+ // TODO. Call a proper parser for the given model format.
+ LOGE("Not supported model format yet.");
+ ret = INFERENCE_ENGINE_ERROR_NOT_SUPPORTED_FORMAT;
+ break;
+ case INFERENCE_MODEL_TFLITE:
+ {
+ // Braces keep model_path's scope local to this case label.
+ std::string model_path = model_paths[0];
+ if (access(model_path.c_str(), F_OK)) {
+ LOGE("Model file path [%s] not found.", model_path.c_str());
+ ret = INFERENCE_ENGINE_ERROR_INVALID_PATH;
+ break;
+ }
+
+ LOGI("Loading %s model file.", model_path.c_str());
+ ret = CreateTfLiteNetwork(model_path);
+ break;
+ }
+ }
+
LOGI("LEAVE");
- return CreateTfLiteNetwork(model_path);
+ return ret;
}
int InferenceARMNN::Load(std::vector<std::string> model_paths, inference_model_format_e model_format)
int ret = INFERENCE_ENGINE_ERROR_NONE;
- std::string model_path = model_paths.front();
- if (access(model_path.c_str(), F_OK)) {
- LOGE("modelFilePath in [%s] ", model_path.c_str());
- return INFERENCE_ENGINE_ERROR_INVALID_PATH;
- }
-
- LOGI("Model File = %s", model_path.c_str());
-
- ret = CreateNetwork(model_path);
+ ret = CreateNetwork(model_paths, model_format);
if (ret != INFERENCE_ENGINE_ERROR_NONE)
return ret;