mv_machine_learning: use DEFAULT_MODEL_NAME for landmark detection task group
[platform/core/api/mediavision.git] / mv_machine_learning / landmark_detection / src / facial_landmark_adapter.cpp
index 58677e9..5f649d7 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "machine_learning_exception.h"
 #include "facial_landmark_adapter.h"
+#include "mv_landmark_detection_config.h"
 
 using namespace std;
 using namespace MediaVision::Common;
@@ -28,10 +29,15 @@ namespace machine_learning
 {
 template<typename T, typename V> FacialLandmarkAdapter<T, V>::FacialLandmarkAdapter() : _source()
 {
-       // In default, Mobilenet v1 ssd model will be used.
-       // If other model is set by user then strategy pattern will be used
-       // to create its corresponding concrete class by calling create().
-       _landmark_detection = make_unique<FldTweakCnn>(LandmarkDetectionTaskType::FLD_TWEAK_CNN);
+       auto config = make_unique<EngineConfig>(MV_CONFIG_PATH + _config_file_name);
+
+       string defaultModelName;
+
+       int ret = config->getStringAttribute(MV_LANDMARK_DETECTION_DEFAULT_MODEL_NAME, &defaultModelName);
+       if (ret != MEDIA_VISION_ERROR_NONE)
+               throw InvalidOperation("Fail to get default model name.");
+
+       create(convertToTaskType(defaultModelName));
 }
 
 template<typename T, typename V> FacialLandmarkAdapter<T, V>::~FacialLandmarkAdapter()
@@ -41,40 +47,47 @@ template<typename T, typename V> FacialLandmarkAdapter<T, V>::~FacialLandmarkAda
 
 template<typename T, typename V> void FacialLandmarkAdapter<T, V>::create(LandmarkDetectionTaskType task_type)
 {
-       // If default task type is same as a given one then skip.
-       if (_landmark_detection->getTaskType() == task_type)
-               return;
-
-       _landmark_detection.reset();
+       if (_landmark_detection) {
+               // If current task type is same as a given one then skip.
+               if (_landmark_detection->getTaskType() == task_type)
+                       return;
+       }
 
        if (task_type == LandmarkDetectionTaskType::FLD_TWEAK_CNN)
                _landmark_detection = make_unique<FldTweakCnn>(task_type);
-       // TODO.
 }
 
 template<typename T, typename V>
-void FacialLandmarkAdapter<T, V>::setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
-                                                                                          const char *model_name)
+LandmarkDetectionTaskType FacialLandmarkAdapter<T, V>::convertToTaskType(string model_name)
 {
-       string model_name_str(model_name);
+       if (model_name.empty())
+               throw InvalidParameter("model name is empty.");
 
-       if (!model_name_str.empty()) {
-               transform(model_name_str.begin(), model_name_str.end(), model_name_str.begin(), ::toupper);
+       transform(model_name.begin(), model_name.end(), model_name.begin(), ::toupper);
 
-               LandmarkDetectionTaskType task_type = LandmarkDetectionTaskType::LANDMARK_DETECTION_TASK_NONE;
+       if (model_name == "FLD_TWEAK_CNN")
+               return LandmarkDetectionTaskType::FLD_TWEAK_CNN;
+       // TODO.
 
-               if (model_name_str == string("FLD_TWEAK_CNN"))
-                       task_type = LandmarkDetectionTaskType::FLD_TWEAK_CNN;
-               // TODO.
-               else
-                       throw InvalidParameter("Invalid landmark detection model name.");
+       throw InvalidParameter("Invalid facial detection model name.");
+}
 
-               create(task_type);
+template<typename T, typename V>
+void FacialLandmarkAdapter<T, V>::setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
+                                                                                          const char *model_name)
+{
+       try {
+               create(convertToTaskType(model_name));
+       } catch (const BaseException &e) {
+               LOGW("A given model name is invalid so default task type will be used.");
        }
 
-       _model_file = string(model_file);
-       _meta_file = string(meta_file);
-       _label_file = string(label_file);
+       if (model_file)
+               _model_file = model_file;
+       if (meta_file)
+               _meta_file = meta_file;
+       if (label_file)
+               _label_file = label_file;
 
        if (_model_file.empty() && _meta_file.empty()) {
                LOGW("Given model info is invalid so default model info will be used instead.");
@@ -92,7 +105,7 @@ void FacialLandmarkAdapter<T, V>::setEngineInfo(const char *engine_type, const c
 
 template<typename T, typename V> void FacialLandmarkAdapter<T, V>::configure()
 {
-       _landmark_detection->configure("facial_landmark.json");
+       _landmark_detection->configure(_config_file_name);
 }
 
 template<typename T, typename V> void FacialLandmarkAdapter<T, V>::getNumberOfEngines(unsigned int *number_of_engines)