mv_inference: Detect model file format and notify a backend of the format
author    Inki Dae <inki.dae@samsung.com>
          Tue, 11 Feb 2020 05:54:32 +0000 (14:54 +0900)
committer Inki Dae <inki.dae@samsung.com>
          Tue, 14 Apr 2020 00:40:31 +0000 (09:40 +0900)
This patch detects the format of a given model file, stores the proper model
file information in a vector, and then notifies the backend engine of the
model file format together with the vector so that the backend engine can
handle the model file correctly.
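
For reference, a minimal sketch of the detection step described above, assuming
only a map keyed by file extension such as the mModelFormats member added below
(the standalone helper and its name are hypothetical, not part of the patch):

    #include <map>
    #include <string>

    // Look up the weight file's extension in a map of supported formats and
    // return the matching format id, or -1 if the extension is not supported.
    static int DetectModelFormat(const std::map<std::string, int> &formats,
                                 const std::string &weightFilePath)
    {
        std::string ext = weightFilePath.substr(weightFilePath.find_last_of(".") + 1);
        std::map<std::string, int>::const_iterator it = formats.find(ext);
        if (it == formats.end())
            return -1; // unsupported model file format

        return it->second; // e.g. the TFLite format id for "model.tflite"
    }

Inference::Load() below performs the same lookup inline and maps the failure
case to MEDIA_VISION_ERROR_INVALID_PARAMETER.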

Change-Id: Id1fa75ee5553e099ff3d2800912aa41fb5f3ef90
Signed-off-by: Inki Dae <inki.dae@samsung.com>
mv_inference/inference/include/Inference.h
mv_inference/inference/src/Inference.cpp

index 3c82d1fce9a3a4c6aed9f50d55fcf267ec214133..223ad155c30190e80fe299bb4b01d7451696199a 100755 (executable)
@@ -270,6 +270,7 @@ private:
        InferenceEngineVision * mBackend;
 
        std::map<int, std::pair<std::string, bool>> mSupportedInferenceBackend;
+       std::map<std::string, int> mModelFormats;
 
 private:
        void CheckSupportedInferenceBackend();
index e201cac406f752a03883d6ebc34842684dbcff7e..ee72a6d66a42ffff97408534f243648c082519a3 100755 (executable)
@@ -69,11 +69,21 @@ Inference::Inference() :
                LOGE("%d: %s: %s", i, (iter->second).first.c_str(), (iter->second).second ? "TRUE" : "FALSE");
        }
 
+       mModelFormats.insert(std::make_pair<std::string, int>("caffemodel", INFERENCE_MODEL_CAFFE));
+       mModelFormats.insert(std::make_pair<std::string, int>("pb", INFERENCE_MODEL_TF));
+       mModelFormats.insert(std::make_pair<std::string, int>("tflite", INFERENCE_MODEL_TFLITE));
+       mModelFormats.insert(std::make_pair<std::string, int>("t7", INFERENCE_MODEL_TORCH));
+       mModelFormats.insert(std::make_pair<std::string, int>("weights", INFERENCE_MODEL_DARKNET));
+       mModelFormats.insert(std::make_pair<std::string, int>("bin", INFERENCE_MODEL_DLDT));
+       mModelFormats.insert(std::make_pair<std::string, int>("onnx", INFERENCE_MODEL_ONNX));
+
        LOGI("LEAVE");
 }
 
 Inference::~Inference()
 {
+       mModelFormats.clear();
+
        // Release backend engine.
        mBackend->UnbindBackend();
 
@@ -338,27 +348,54 @@ int Inference::Load(void)
 {
        LOGI("ENTER");
 
-       // Add model files to load.
-    // TODO. model file and its corresponding label file should be added by
-    // user request.
-    std::vector<std::string> models;
-    models.push_back(mConfig.mWeightFilePath);
-    models.push_back(mConfig.mUserFilePath);
+       // Check if model file is valid or not.
+       std::string ext_str = mConfig.mWeightFilePath.substr(mConfig.mWeightFilePath.find_last_of(".") + 1);
+       std::map<std::string, int>::iterator key = mModelFormats.find(ext_str);
+       if (key == mModelFormats.end()) {
+               LOGE("Invalid model file format.(ext = %s)", ext_str.c_str());
+               return MEDIA_VISION_ERROR_INVALID_PARAMETER;
+       }
+
+       LOGI("%s model file has been detected.", ext_str.c_str());
+
+       std::vector<std::string> models;
+
+       // Push model file information to models vector properly according to detected model format.
+       switch (key->second) {
+       case INFERENCE_MODEL_CAFFE:
+       case INFERENCE_MODEL_TF:
+       case INFERENCE_MODEL_DARKNET:
+       case INFERENCE_MODEL_DLDT:
+       case INFERENCE_MODEL_ONNX:
+               models.push_back(mConfig.mWeightFilePath);
+               models.push_back(mConfig.mConfigFilePath);
+               break;
+       case INFERENCE_MODEL_TFLITE:
+       case INFERENCE_MODEL_TORCH:
+               models.push_back(mConfig.mWeightFilePath);
+               break;
+       default:
+               break;
+       }
+
+       models.push_back(mConfig.mUserFilePath);
 
     // Request model loading to backend engine.
-    int ret = mBackend->Load(models, 1);
+    int ret = mBackend->Load(models, (inference_model_format_e)key->second);
        if (ret != INFERENCE_ENGINE_ERROR_NONE) {
                delete mBackend;
                LOGE("Fail to load model");
                mCanRun = false;
-               return ConvertEngineErrorToVisionError(ret);
+               goto out;
        }
 
        mCanRun = true;
+out:
+       std::vector<std::string>().swap(models);
 
        LOGI("LEAVE");
 
-       return MEDIA_VISION_ERROR_NONE;
+       return ConvertEngineErrorToVisionError(ret);
 }
 
 int Inference::Run(mv_source_h mvSource, mv_rectangle_s *roi)