LOGE("%d: %s: %s", i, (iter->second).first.c_str(), (iter->second).second ? "TRUE" : "FALSE");
}
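+ // Register each supported weight-file extension with its model format.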
+ mModelFormats.insert(std::make_pair("caffemodel", INFERENCE_MODEL_CAFFE));
+ mModelFormats.insert(std::make_pair("pb", INFERENCE_MODEL_TF));
+ mModelFormats.insert(std::make_pair("tflite", INFERENCE_MODEL_TFLITE));
+ mModelFormats.insert(std::make_pair("t7", INFERENCE_MODEL_TORCH));
+ mModelFormats.insert(std::make_pair("weights", INFERENCE_MODEL_DARKNET));
+ mModelFormats.insert(std::make_pair("bin", INFERENCE_MODEL_DLDT));
+ mModelFormats.insert(std::make_pair("onnx", INFERENCE_MODEL_ONNX));
+
LOGI("LEAVE");
}
Inference::~Inference()
{
+ mModelFormats.clear();
+
// Release backend engine.
mBackend->UnbindBackend();
{
LOGI("ENTER");
- // Add model files to load.
- // TODO. model file and its corresponding label file should be added by
- // user request.
- std::vector<std::string> models;
- models.push_back(mConfig.mWeightFilePath);
- models.push_back(mConfig.mUserFilePath);
+ // Check that the model file has a supported format, judged by its extension.
+ // (If the path contains no '.', the whole path is looked up and rejected below.)
+ std::string ext_str = mConfig.mWeightFilePath.substr(mConfig.mWeightFilePath.find_last_of('.') + 1);
+ std::map<std::string, int>::iterator key = mModelFormats.find(ext_str);
+ if (key == mModelFormats.end()) {
+ LOGE("Invalid model file format.(ext = %s)", ext_str.c_str());
+ return MEDIA_VISION_ERROR_INVALID_PARAMETER;
+ }
+
+ LOGI("%s model file has been detected.", ext_str.c_str());
+
+ std::vector<std::string> models;
+
+ // Populate the models vector according to the detected model format:
+ // most formats take a weight file plus a config file, while TFLite and
+ // Torch models consist of a single weight file.
+ switch (key->second) {
+ case INFERENCE_MODEL_CAFFE:
+ case INFERENCE_MODEL_TF:
+ case INFERENCE_MODEL_DARKNET:
+ case INFERENCE_MODEL_DLDT:
+ case INFERENCE_MODEL_ONNX:
+ models.push_back(mConfig.mWeightFilePath);
+ models.push_back(mConfig.mConfigFilePath);
+ break;
+ case INFERENCE_MODEL_TFLITE:
+ case INFERENCE_MODEL_TORCH:
+ models.push_back(mConfig.mWeightFilePath);
+ break;
+ default:
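+ // Unreachable: the extension was already validated against mModelFormats above.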
+ break;
+ }
+
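+ // The user (label) file is passed to the backend together with the model files.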
+ models.push_back(mConfig.mUserFilePath);
// Request model loading to backend engine.
- int ret = mBackend->Load(models, 1);
+ int ret = mBackend->Load(models, static_cast<inference_model_format_e>(key->second));
if (ret != INFERENCE_ENGINE_ERROR_NONE) {
delete mBackend;
LOGE("Fail to load model");
mCanRun = false;
- return ConvertEngineErrorToVisionError(ret);
+ goto out;
}
mCanRun = true;
+out:
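+ // Explicitly release the models vector's heap storage before returning.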
+ std::vector<std::string>().swap(models);
LOGI("LEAVE");
- return MEDIA_VISION_ERROR_NONE;
+ return ConvertEngineErrorToVisionError(ret);
}
int Inference::Run(mv_source_h mvSource, mv_rectangle_s *roi)