From 38d71ecfb0db3a160bb50ec2e9dbbcb67243be74 Mon Sep 17 00:00:00 2001 From: Vibhav Aggarwal Date: Mon, 6 Nov 2023 18:32:59 +0900 Subject: [PATCH] mv_machine_learning: reallocate parser on changing meta file [Issue type] bug fix The MetaParser needs to be reallocated when the user calls mv_facial_landmark_set_model() or mv_pose_landmark_set_model(). Change-Id: I7f21c7d36b3ffb9b869a1998bcd8ee19a365fd38 Signed-off-by: Vibhav Aggarwal --- .../include/landmark_detection_config.h | 4 +--- .../src/facial_landmark_adapter.cpp | 10 ++++---- .../src/landmark_detection_config.cpp | 27 ++++++++-------------- .../src/pose_landmark_adapter.cpp | 10 ++++---- 4 files changed, 21 insertions(+), 30 deletions(-) diff --git a/mv_machine_learning/landmark_detection/include/landmark_detection_config.h b/mv_machine_learning/landmark_detection/include/landmark_detection_config.h index c8fa1b0..4aa4d9e 100644 --- a/mv_machine_learning/landmark_detection/include/landmark_detection_config.h +++ b/mv_machine_learning/landmark_detection/include/landmark_detection_config.h @@ -32,7 +32,6 @@ class LandmarkDetectionConfig { private: std::unique_ptr _parser; - LandmarkDetectionTaskType _task_type { LandmarkDetectionTaskType::LANDMARK_DETECTION_TASK_NONE }; std::string _defaultModelName; std::string _modelFilePath; std::string _modelMetaFilePath; @@ -49,10 +48,8 @@ public: void setUserModel(const std::string &model_file, const std::string &meta_file, const std::string &label_file); void parseConfigFile(const std::string &configFilePath); void parseMetaFile(); - void setTaskType(LandmarkDetectionTaskType task_type); void setBackendType(int backend_type); void setTargetDeviceType(int device_type); - LandmarkDetectionTaskType getTaskType() const; const std::string &getDefaultModelName() const; const std::string &getModelFilePath() const; const std::string &getLabelFilePath() const; @@ -61,6 +58,7 @@ public: double getConfidenceThreshold() const; int getBackendType() const; int getTargetDeviceType() const; + void loadMetaFile(LandmarkDetectionTaskType task_type); }; } // machine_learning diff --git a/mv_machine_learning/landmark_detection/src/facial_landmark_adapter.cpp b/mv_machine_learning/landmark_detection/src/facial_landmark_adapter.cpp index f912d21..60a8f85 100644 --- a/mv_machine_learning/landmark_detection/src/facial_landmark_adapter.cpp +++ b/mv_machine_learning/landmark_detection/src/facial_landmark_adapter.cpp @@ -33,9 +33,7 @@ template FacialLandmarkAdapter::FacialLandmarkAdap _config->parseConfigFile(_config_file_name); LandmarkDetectionTaskType model_type = convertToTaskType(_config->getDefaultModelName()); - create(model_type); - _config->setTaskType(model_type); } template FacialLandmarkAdapter::~FacialLandmarkAdapter() @@ -51,6 +49,9 @@ template void FacialLandmarkAdapter::create(Landma return; } + // if model name is changed by user then reallocate the parser and reload the meta file corresponding to the model name. + _config->loadMetaFile(task_type); + if (task_type == LandmarkDetectionTaskType::FLD_TWEAK_CNN) _landmark_detection = make_unique(task_type, _config); } @@ -75,11 +76,10 @@ void FacialLandmarkAdapter::setModelInfo(const char *model_file, const cha const char *model_name) { try { - LandmarkDetectionTaskType model_type = convertToTaskType(model_name); - _config->setUserModel(model_file, meta_file, label_file); + + LandmarkDetectionTaskType model_type = convertToTaskType(model_name); create(model_type); - _config->setTaskType(model_type); } catch (const BaseException &e) { LOGW("A given model name is invalid so default task type will be used."); } diff --git a/mv_machine_learning/landmark_detection/src/landmark_detection_config.cpp b/mv_machine_learning/landmark_detection/src/landmark_detection_config.cpp index 90c9767..5974432 100644 --- a/mv_machine_learning/landmark_detection/src/landmark_detection_config.cpp +++ b/mv_machine_learning/landmark_detection/src/landmark_detection_config.cpp @@ -29,14 +29,7 @@ namespace mediavision namespace machine_learning { LandmarkDetectionConfig::LandmarkDetectionConfig() -{ - _parser = make_unique(); -} - -void LandmarkDetectionConfig::setTaskType(LandmarkDetectionTaskType task_type) -{ - _task_type = task_type; -} +{} void LandmarkDetectionConfig::setBackendType(int backend_type) { @@ -48,11 +41,6 @@ void LandmarkDetectionConfig::setTargetDeviceType(int device_type) _targetDeviceType = device_type; } -LandmarkDetectionTaskType LandmarkDetectionConfig::getTaskType() const -{ - return _task_type; -} - const std::string &LandmarkDetectionConfig::getDefaultModelName() const { return _defaultModelName; @@ -97,9 +85,9 @@ void LandmarkDetectionConfig::setUserModel(const string &model_file, const strin { if (!model_file.empty()) _modelFilePath = _modelDefaultPath + model_file; - if (meta_file.empty()) + if (!meta_file.empty()) _modelMetaFilePath = _modelDefaultPath + meta_file; - if (label_file.empty()) + if (!label_file.empty()) _modelLabelFilePath = _modelDefaultPath + label_file; } @@ -152,8 +140,6 @@ void LandmarkDetectionConfig::parseConfigFile(const std::string &configFilePath) _modelMetaFilePath = _modelDefaultPath + _modelMetaFilePath; LOGI("meta file path = %s", _modelMetaFilePath.c_str()); - _parser->load(_modelMetaFilePath); - ret = config->getDoubleAttribute(MV_LANDMARK_DETECTION_CONFIDENCE_THRESHOLD, &_confidence_threshold); if (ret != MEDIA_VISION_ERROR_NONE) LOGW("threshold value doesn't exist."); @@ -171,5 +157,12 @@ void LandmarkDetectionConfig::parseConfigFile(const std::string &configFilePath) LOGI("label file path = %s", _modelLabelFilePath.c_str()); } +void LandmarkDetectionConfig::loadMetaFile(LandmarkDetectionTaskType task_type) +{ + _parser = make_unique(); + _parser->setTaskType(static_cast(task_type)); + _parser->load(_modelMetaFilePath); +} + } } \ No newline at end of file diff --git a/mv_machine_learning/landmark_detection/src/pose_landmark_adapter.cpp b/mv_machine_learning/landmark_detection/src/pose_landmark_adapter.cpp index 92f361d..aee6adb 100644 --- a/mv_machine_learning/landmark_detection/src/pose_landmark_adapter.cpp +++ b/mv_machine_learning/landmark_detection/src/pose_landmark_adapter.cpp @@ -33,9 +33,7 @@ template PoseLandmarkAdapter::PoseLandmarkAdapter( _config->parseConfigFile(_config_file_name); LandmarkDetectionTaskType model_type = convertToTaskType(_config->getDefaultModelName()); - create(model_type); - _config->setTaskType(model_type); } template PoseLandmarkAdapter::~PoseLandmarkAdapter() @@ -53,6 +51,9 @@ template void PoseLandmarkAdapter::create(Landmark return; } + // if model name is changed by user then reallocate the parser and reload the meta file corresponding to the model name. + _config->loadMetaFile(task_type); + if (task_type == LandmarkDetectionTaskType::PLD_CPM) _landmark_detection = make_unique(task_type, _config); } @@ -76,11 +77,10 @@ void PoseLandmarkAdapter::setModelInfo(const char *model_file, const char const char *model_name) { try { - LandmarkDetectionTaskType model_type = convertToTaskType(model_name); - _config->setUserModel(model_file, meta_file, label_file); + + LandmarkDetectionTaskType model_type = convertToTaskType(model_name); create(model_type); - _config->setTaskType(model_type); } catch (const BaseException &e) { LOGW("A given model name is invalid so default task type will be used."); } -- 2.7.4