mv_machine_learning: reallocate parser on changing meta file
authorVibhav Aggarwal <v.aggarwal@samsung.com>
Mon, 6 Nov 2023 09:32:59 +0000 (18:32 +0900)
committerInki Dae <inki.dae@samsung.com>
Tue, 14 Nov 2023 07:39:43 +0000 (16:39 +0900)
[Issue type] bug fix

The MetaParser needs to be reallocated when the user
calls mv_facial_landmark_set_model() or mv_pose_landmark_set_model().

Change-Id: I7f21c7d36b3ffb9b869a1998bcd8ee19a365fd38
Signed-off-by: Vibhav Aggarwal <v.aggarwal@samsung.com>
mv_machine_learning/landmark_detection/include/landmark_detection_config.h
mv_machine_learning/landmark_detection/src/facial_landmark_adapter.cpp
mv_machine_learning/landmark_detection/src/landmark_detection_config.cpp
mv_machine_learning/landmark_detection/src/pose_landmark_adapter.cpp

index c8fa1b0..4aa4d9e 100644 (file)
@@ -32,7 +32,6 @@ class LandmarkDetectionConfig
 {
 private:
        std::unique_ptr<MetaParser> _parser;
-       LandmarkDetectionTaskType _task_type { LandmarkDetectionTaskType::LANDMARK_DETECTION_TASK_NONE };
        std::string _defaultModelName;
        std::string _modelFilePath;
        std::string _modelMetaFilePath;
@@ -49,10 +48,8 @@ public:
        void setUserModel(const std::string &model_file, const std::string &meta_file, const std::string &label_file);
        void parseConfigFile(const std::string &configFilePath);
        void parseMetaFile();
-       void setTaskType(LandmarkDetectionTaskType task_type);
        void setBackendType(int backend_type);
        void setTargetDeviceType(int device_type);
-       LandmarkDetectionTaskType getTaskType() const;
        const std::string &getDefaultModelName() const;
        const std::string &getModelFilePath() const;
        const std::string &getLabelFilePath() const;
@@ -61,6 +58,7 @@ public:
        double getConfidenceThreshold() const;
        int getBackendType() const;
        int getTargetDeviceType() const;
+       void loadMetaFile(LandmarkDetectionTaskType task_type);
 };
 
 } // machine_learning
index f912d21..60a8f85 100644 (file)
@@ -33,9 +33,7 @@ template<typename T, typename V> FacialLandmarkAdapter<T, V>::FacialLandmarkAdap
        _config->parseConfigFile(_config_file_name);
 
        LandmarkDetectionTaskType model_type = convertToTaskType(_config->getDefaultModelName());
-
        create(model_type);
-       _config->setTaskType(model_type);
 }
 
 template<typename T, typename V> FacialLandmarkAdapter<T, V>::~FacialLandmarkAdapter()
@@ -51,6 +49,9 @@ template<typename T, typename V> void FacialLandmarkAdapter<T, V>::create(Landma
                        return;
        }
 
+       // if model name is changed by user then reallocate the parser and reload the meta file corresponding to the model name.
+       _config->loadMetaFile(task_type);
+
        if (task_type == LandmarkDetectionTaskType::FLD_TWEAK_CNN)
                _landmark_detection = make_unique<FldTweakCnn>(task_type, _config);
 }
@@ -75,11 +76,10 @@ void FacialLandmarkAdapter<T, V>::setModelInfo(const char *model_file, const cha
                                                                                           const char *model_name)
 {
        try {
-               LandmarkDetectionTaskType model_type = convertToTaskType(model_name);
-
                _config->setUserModel(model_file, meta_file, label_file);
+
+               LandmarkDetectionTaskType model_type = convertToTaskType(model_name);
                create(model_type);
-               _config->setTaskType(model_type);
        } catch (const BaseException &e) {
                LOGW("A given model name is invalid so default task type will be used.");
        }
index 90c9767..5974432 100644 (file)
@@ -29,14 +29,7 @@ namespace mediavision
 namespace machine_learning
 {
 LandmarkDetectionConfig::LandmarkDetectionConfig()
-{
-       _parser = make_unique<LandmarkDetectionParser>();
-}
-
-void LandmarkDetectionConfig::setTaskType(LandmarkDetectionTaskType task_type)
-{
-       _task_type = task_type;
-}
+{}
 
 void LandmarkDetectionConfig::setBackendType(int backend_type)
 {
@@ -48,11 +41,6 @@ void LandmarkDetectionConfig::setTargetDeviceType(int device_type)
        _targetDeviceType = device_type;
 }
 
-LandmarkDetectionTaskType LandmarkDetectionConfig::getTaskType() const
-{
-       return _task_type;
-}
-
 const std::string &LandmarkDetectionConfig::getDefaultModelName() const
 {
        return _defaultModelName;
@@ -97,9 +85,9 @@ void LandmarkDetectionConfig::setUserModel(const string &model_file, const strin
 {
        if (!model_file.empty())
                _modelFilePath = _modelDefaultPath + model_file;
-       if (meta_file.empty())
+       if (!meta_file.empty())
                _modelMetaFilePath = _modelDefaultPath + meta_file;
-       if (label_file.empty())
+       if (!label_file.empty())
                _modelLabelFilePath = _modelDefaultPath + label_file;
 }
 
@@ -152,8 +140,6 @@ void LandmarkDetectionConfig::parseConfigFile(const std::string &configFilePath)
        _modelMetaFilePath = _modelDefaultPath + _modelMetaFilePath;
        LOGI("meta file path = %s", _modelMetaFilePath.c_str());
 
-       _parser->load(_modelMetaFilePath);
-
        ret = config->getDoubleAttribute(MV_LANDMARK_DETECTION_CONFIDENCE_THRESHOLD, &_confidence_threshold);
        if (ret != MEDIA_VISION_ERROR_NONE)
                LOGW("threshold value doesn't exist.");
@@ -171,5 +157,12 @@ void LandmarkDetectionConfig::parseConfigFile(const std::string &configFilePath)
        LOGI("label file path = %s", _modelLabelFilePath.c_str());
 }
 
+void LandmarkDetectionConfig::loadMetaFile(LandmarkDetectionTaskType task_type)
+{
+       _parser = make_unique<LandmarkDetectionParser>();
+       _parser->setTaskType(static_cast<int>(task_type));
+       _parser->load(_modelMetaFilePath);
+}
+
 }
 }
\ No newline at end of file
index 92f361d..aee6adb 100644 (file)
@@ -33,9 +33,7 @@ template<typename T, typename V> PoseLandmarkAdapter<T, V>::PoseLandmarkAdapter(
        _config->parseConfigFile(_config_file_name);
 
        LandmarkDetectionTaskType model_type = convertToTaskType(_config->getDefaultModelName());
-
        create(model_type);
-       _config->setTaskType(model_type);
 }
 
 template<typename T, typename V> PoseLandmarkAdapter<T, V>::~PoseLandmarkAdapter()
@@ -53,6 +51,9 @@ template<typename T, typename V> void PoseLandmarkAdapter<T, V>::create(Landmark
                        return;
        }
 
+       // if model name is changed by user then reallocate the parser and reload the meta file corresponding to the model name.
+       _config->loadMetaFile(task_type);
+
        if (task_type == LandmarkDetectionTaskType::PLD_CPM)
                _landmark_detection = make_unique<PldCpm>(task_type, _config);
 }
@@ -76,11 +77,10 @@ void PoseLandmarkAdapter<T, V>::setModelInfo(const char *model_file, const char
                                                                                         const char *model_name)
 {
        try {
-               LandmarkDetectionTaskType model_type = convertToTaskType(model_name);
-
                _config->setUserModel(model_file, meta_file, label_file);
+
+               LandmarkDetectionTaskType model_type = convertToTaskType(model_name);
                create(model_type);
-               _config->setTaskType(model_type);
        } catch (const BaseException &e) {
                LOGW("A given model name is invalid so default task type will be used.");
        }