mv_machine_learning: drop parsing dependency from landmark detection
authorInki Dae <inki.dae@samsung.com>
Tue, 31 Oct 2023 03:45:56 +0000 (12:45 +0900)
committerInki Dae <inki.dae@samsung.com>
Tue, 14 Nov 2023 07:38:57 +0000 (16:38 +0900)
[Issue type] : code refactoring

Drop the configuration and meta file parsing dependency from LandmarkDetection
class.

Until now, the concrete class of each task group got the task group
configuration information from its own configuration file, and also
included a MetaParser class object to get the tensor information
corresponding to a given model file.

However, these dependencies led to a code smell, divergent change[1], even
though the concrete class does not need to depend on parsing the configuration
and meta files - it needs only the information obtained after parsing.

As a first refactoring work, this patch extracts the parsing portion from the
LandmarkDetection class and introduces it as a new class,
LandmarkDetectionConfig.

With this, the adapter classes of the landmark detection task group will parse
the configuration and meta files before creating the LandmarkDetection class,
and then they will create the LandmarkDetection class with the needed
information. As a result, we can manage the LandmarkDetection class without
any dependency on parsing work.

[1] https://refactoring.guru/smells/divergent-change

Change-Id: I29f2d684e2b6e698dcf36a8294c608e90dda67c0
Signed-off-by: Inki Dae <inki.dae@samsung.com>
13 files changed:
mv_machine_learning/landmark_detection/include/facial_landmark_adapter.h
mv_machine_learning/landmark_detection/include/fld_tweak_cnn.h
mv_machine_learning/landmark_detection/include/landmark_detection.h
mv_machine_learning/landmark_detection/include/landmark_detection_config.h [new file with mode: 0644]
mv_machine_learning/landmark_detection/include/landmark_detection_type.h
mv_machine_learning/landmark_detection/include/pld_cpm.h
mv_machine_learning/landmark_detection/include/pose_landmark_adapter.h
mv_machine_learning/landmark_detection/src/facial_landmark_adapter.cpp
mv_machine_learning/landmark_detection/src/fld_tweak_cnn.cpp
mv_machine_learning/landmark_detection/src/landmark_detection.cpp
mv_machine_learning/landmark_detection/src/landmark_detection_config.cpp [new file with mode: 0644]
mv_machine_learning/landmark_detection/src/pld_cpm.cpp
mv_machine_learning/landmark_detection/src/pose_landmark_adapter.cpp

index 5061e75..967bf9f 100644 (file)
@@ -21,6 +21,7 @@
 
 #include "EngineConfig.h"
 #include "itask.h"
+#include "landmark_detection_config.h"
 #include "fld_tweak_cnn.h"
 
 namespace mediavision
@@ -31,11 +32,8 @@ template<typename T, typename V> class FacialLandmarkAdapter : public mediavisio
 {
 private:
        std::unique_ptr<LandmarkDetection> _landmark_detection;
+       std::shared_ptr<LandmarkDetectionConfig> _config;
        T _source;
-       std::string _model_name;
-       std::string _model_file;
-       std::string _meta_file;
-       std::string _label_file;
        const std::string _config_file_name = "facial_landmark.json";
 
        void create(LandmarkDetectionTaskType task_type);
index 6f60474..9587633 100644 (file)
@@ -35,7 +35,7 @@ private:
        LandmarkDetectionResult _result;
 
 public:
-       FldTweakCnn(LandmarkDetectionTaskType task_type);
+       FldTweakCnn(LandmarkDetectionTaskType task_type, std::shared_ptr<LandmarkDetectionConfig> config);
        ~FldTweakCnn();
 
        LandmarkDetectionResult &result() override;
index 1935dde..8db218e 100644 (file)
@@ -25,7 +25,9 @@
 #include "inference_engine_common_impl.h"
 #include "Inference.h"
 #include "landmark_detection_type.h"
+#include "MetaParser.h"
 #include "LandmarkDetectionParser.h"
+#include "landmark_detection_config.h"
 #include "machine_learning_preprocess.h"
 #include "async_manager.h"
 
@@ -52,43 +54,36 @@ private:
 
 protected:
        std::unique_ptr<mediavision::inference::Inference> _inference;
-       std::unique_ptr<MediaVision::Common::EngineConfig> _config;
-       std::unique_ptr<MetaParser> _parser;
+       std::shared_ptr<LandmarkDetectionConfig> _config;
        std::vector<std::string> _labels;
        std::vector<std::string> _valid_backends;
        std::vector<std::string> _valid_devices;
        Preprocess _preprocess;
-       std::string _modelFilePath;
-       std::string _modelMetaFilePath;
-       std::string _modelDefaultPath;
-       std::string _modelLabelFilePath;
-       int _backendType {};
-       int _targetDeviceType {};
-       double _confidence_threshold {};
 
        void getOutputNames(std::vector<std::string> &names);
        void getOutputTensor(std::string target_name, std::vector<float> &tensor);
-       void parseMetaFile(const std::string &meta_file_name);
-
        template<typename T> void inference(std::vector<std::vector<T> > &inputVectors);
        virtual LandmarkDetectionResult &result() = 0;
 
 public:
-       LandmarkDetection(LandmarkDetectionTaskType task_type);
+       LandmarkDetection(LandmarkDetectionTaskType task_type, std::shared_ptr<LandmarkDetectionConfig> config);
        virtual ~LandmarkDetection() = default;
        void preDestroy();
        LandmarkDetectionTaskType getTaskType();
        void setUserModel(std::string model_file, std::string meta_file, std::string label_file);
-       void setEngineInfo(std::string engine_type, std::string device_type);
+       void setEngineInfo(std::string engine_type_name, std::string device_type_name);
        void getNumberOfEngines(unsigned int *number_of_engines);
        void getEngineType(unsigned int engine_index, char **engine_type);
        void getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices);
        void getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type);
-       void configure(const std::string &configFile);
+       void configure();
        void prepare();
        void perform(mv_source_h &mv_src);
        void performAsync(LandmarkDetectionInput &input);
        LandmarkDetectionResult &getOutput();
+       void setInputMetaMap(const MetaMap &input_meta_map);
+       void setOutputMetaMap(const MetaMap &output_meta_map);
+       void setConfidenceThreshold(const double &confidence_threshold);
 };
 
 } // machine_learning
diff --git a/mv_machine_learning/landmark_detection/include/landmark_detection_config.h b/mv_machine_learning/landmark_detection/include/landmark_detection_config.h
new file mode 100644 (file)
index 0000000..c8fa1b0
--- /dev/null
@@ -0,0 +1,69 @@
+/**
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LANDMARK_DETECTION_CONFIG_H__
+#define __LANDMARK_DETECTION_CONFIG_H__
+
+#include <mv_common.h>
+#include "mv_private.h"
+#include "EngineConfig.h"
+
+#include "MetaParser.h"
+#include "landmark_detection_type.h"
+
+namespace mediavision
+{
+namespace machine_learning
+{
+class LandmarkDetectionConfig
+{
+private:
+       std::unique_ptr<MetaParser> _parser;
+       LandmarkDetectionTaskType _task_type { LandmarkDetectionTaskType::LANDMARK_DETECTION_TASK_NONE };
+       std::string _defaultModelName;
+       std::string _modelFilePath;
+       std::string _modelMetaFilePath;
+       std::string _modelDefaultPath;
+       std::string _modelLabelFilePath;
+       int _backendType {};
+       int _targetDeviceType {};
+       double _confidence_threshold {};
+
+public:
+       LandmarkDetectionConfig();
+       virtual ~LandmarkDetectionConfig() = default;
+
+       void setUserModel(const std::string &model_file, const std::string &meta_file, const std::string &label_file);
+       void parseConfigFile(const std::string &configFilePath);
+       void parseMetaFile();
+       void setTaskType(LandmarkDetectionTaskType task_type);
+       void setBackendType(int backend_type);
+       void setTargetDeviceType(int device_type);
+       LandmarkDetectionTaskType getTaskType() const;
+       const std::string &getDefaultModelName() const;
+       const std::string &getModelFilePath() const;
+       const std::string &getLabelFilePath() const;
+       MetaMap &getInputMetaMap() const;
+       MetaMap &getOutputMetaMap() const;
+       double getConfidenceThreshold() const;
+       int getBackendType() const;
+       int getTargetDeviceType() const;
+};
+
+} // machine_learning
+} // mediavision
+
+#endif
\ No newline at end of file
index 13ce7c5..18a3d2c 100644 (file)
@@ -39,11 +39,7 @@ struct LandmarkDetectionResult {
        std::vector<std::string> labels;
 };
 
-enum class LandmarkDetectionTaskType {
-       LANDMARK_DETECTION_TASK_NONE = 0,
-       FLD_TWEAK_CNN,
-       PLD_CPM
-};
+enum class LandmarkDetectionTaskType { LANDMARK_DETECTION_TASK_NONE = 0, FLD_TWEAK_CNN, PLD_CPM };
 
 }
 }
index ba9defb..5a058ef 100644 (file)
@@ -35,7 +35,7 @@ private:
        LandmarkDetectionResult _result;
 
 public:
-       PldCpm(LandmarkDetectionTaskType task_type);
+       PldCpm(LandmarkDetectionTaskType task_type, std::shared_ptr<LandmarkDetectionConfig> config);
        ~PldCpm();
 
        LandmarkDetectionResult &result() override;
index e8879b6..767e488 100644 (file)
@@ -21,6 +21,7 @@
 
 #include "EngineConfig.h"
 #include "itask.h"
+#include "landmark_detection_config.h"
 #include "pld_cpm.h"
 
 namespace mediavision
@@ -31,11 +32,8 @@ template<typename T, typename V> class PoseLandmarkAdapter : public mediavision:
 {
 private:
        std::unique_ptr<LandmarkDetection> _landmark_detection;
+       std::shared_ptr<LandmarkDetectionConfig> _config;
        T _source;
-       std::string _model_name;
-       std::string _model_file;
-       std::string _meta_file;
-       std::string _label_file;
        const std::string _config_file_name = "pose_landmark.json";
 
        void create(LandmarkDetectionTaskType task_type);
index 5f649d7..f912d21 100644 (file)
@@ -29,15 +29,13 @@ namespace machine_learning
 {
 template<typename T, typename V> FacialLandmarkAdapter<T, V>::FacialLandmarkAdapter() : _source()
 {
-       auto config = make_unique<EngineConfig>(MV_CONFIG_PATH + _config_file_name);
+       _config = make_shared<LandmarkDetectionConfig>();
+       _config->parseConfigFile(_config_file_name);
 
-       string defaultModelName;
+       LandmarkDetectionTaskType model_type = convertToTaskType(_config->getDefaultModelName());
 
-       int ret = config->getStringAttribute(MV_LANDMARK_DETECTION_DEFAULT_MODEL_NAME, &defaultModelName);
-       if (ret != MEDIA_VISION_ERROR_NONE)
-               throw InvalidOperation("Fail to get default model name.");
-
-       create(convertToTaskType(defaultModelName));
+       create(model_type);
+       _config->setTaskType(model_type);
 }
 
 template<typename T, typename V> FacialLandmarkAdapter<T, V>::~FacialLandmarkAdapter()
@@ -54,7 +52,7 @@ template<typename T, typename V> void FacialLandmarkAdapter<T, V>::create(Landma
        }
 
        if (task_type == LandmarkDetectionTaskType::FLD_TWEAK_CNN)
-               _landmark_detection = make_unique<FldTweakCnn>(task_type);
+               _landmark_detection = make_unique<FldTweakCnn>(task_type, _config);
 }
 
 template<typename T, typename V>
@@ -77,24 +75,19 @@ void FacialLandmarkAdapter<T, V>::setModelInfo(const char *model_file, const cha
                                                                                           const char *model_name)
 {
        try {
-               create(convertToTaskType(model_name));
+               LandmarkDetectionTaskType model_type = convertToTaskType(model_name);
+
+               _config->setUserModel(model_file, meta_file, label_file);
+               create(model_type);
+               _config->setTaskType(model_type);
        } catch (const BaseException &e) {
                LOGW("A given model name is invalid so default task type will be used.");
        }
 
-       if (model_file)
-               _model_file = model_file;
-       if (meta_file)
-               _meta_file = meta_file;
-       if (label_file)
-               _label_file = label_file;
-
-       if (_model_file.empty() && _meta_file.empty()) {
+       if (!model_file && !meta_file) {
                LOGW("Given model info is invalid so default model info will be used instead.");
                return;
        }
-
-       _landmark_detection->setUserModel(_model_file, _meta_file, _label_file);
 }
 
 template<typename T, typename V>
@@ -105,7 +98,7 @@ void FacialLandmarkAdapter<T, V>::setEngineInfo(const char *engine_type, const c
 
 template<typename T, typename V> void FacialLandmarkAdapter<T, V>::configure()
 {
-       _landmark_detection->configure(_config_file_name);
+       _landmark_detection->configure();
 }
 
 template<typename T, typename V> void FacialLandmarkAdapter<T, V>::getNumberOfEngines(unsigned int *number_of_engines)
index b391f26..71c451e 100644 (file)
@@ -19,7 +19,6 @@
 #include <algorithm>
 
 #include "machine_learning_exception.h"
-#include "mv_landmark_detection_config.h"
 #include "fld_tweak_cnn.h"
 #include "Postprocess.h"
 
@@ -31,7 +30,8 @@ namespace mediavision
 {
 namespace machine_learning
 {
-FldTweakCnn::FldTweakCnn(LandmarkDetectionTaskType task_type) : LandmarkDetection(task_type), _result()
+FldTweakCnn::FldTweakCnn(LandmarkDetectionTaskType task_type, std::shared_ptr<LandmarkDetectionConfig> config)
+               : LandmarkDetection(task_type, config), _result()
 {}
 
 FldTweakCnn::~FldTweakCnn()
@@ -49,7 +49,7 @@ LandmarkDetectionResult &FldTweakCnn::result()
 
        LandmarkDetection::getOutputNames(names);
 
-       auto scoreMetaInfo = _parser->getOutputMetaMap().at(names[0]);
+       auto scoreMetaInfo = _config->getOutputMetaMap().at(names[0]);
        auto decodingLandmark =
                        static_pointer_cast<DecodingLandmark>(scoreMetaInfo->decodingTypeMap[DecodingType::LANDMARK]);
 
index 83b2aad..6b6af1e 100644 (file)
@@ -22,7 +22,6 @@
 
 #include "machine_learning_exception.h"
 #include "mv_machine_learning_common.h"
-#include "mv_landmark_detection_config.h"
 #include "landmark_detection.h"
 
 using namespace std;
@@ -35,13 +34,11 @@ namespace mediavision
 {
 namespace machine_learning
 {
-LandmarkDetection::LandmarkDetection(LandmarkDetectionTaskType task_type)
-               : _task_type(task_type)
-               , _backendType(MV_INFERENCE_BACKEND_NONE)
-               , _targetDeviceType(MV_INFERENCE_TARGET_DEVICE_NONE)
+LandmarkDetection::LandmarkDetection(LandmarkDetectionTaskType task_type, shared_ptr<LandmarkDetectionConfig> config)
+               : _task_type(task_type), _config(config)
 {
        _inference = make_unique<Inference>();
-       _parser = make_unique<LandmarkDetectionParser>();
+       loadLabel();
 }
 
 void LandmarkDetection::preDestroy()
@@ -77,23 +74,25 @@ void LandmarkDetection::getDeviceList(const char *engine_type)
        _valid_devices.push_back("gpu");
 }
 
-void LandmarkDetection::setEngineInfo(std::string engine_type, std::string device_type)
+void LandmarkDetection::setEngineInfo(string engine_type_name, string device_type_name)
 {
-       if (engine_type.empty() || device_type.empty())
+       if (engine_type_name.empty() || device_type_name.empty())
                throw InvalidParameter("Invalid engine info.");
 
-       transform(engine_type.begin(), engine_type.end(), engine_type.begin(), ::toupper);
-       transform(device_type.begin(), device_type.end(), device_type.begin(), ::toupper);
-
-       _backendType = GetBackendType(engine_type);
-       _targetDeviceType = GetDeviceType(device_type);
+       transform(engine_type_name.begin(), engine_type_name.end(), engine_type_name.begin(), ::toupper);
+       transform(device_type_name.begin(), device_type_name.end(), device_type_name.begin(), ::toupper);
 
-       LOGI("Engine type : %s => %d, Device type : %s => %d", engine_type.c_str(), GetBackendType(engine_type),
-                device_type.c_str(), GetDeviceType(device_type));
+       int engine_type = GetBackendType(engine_type_name);
+       int device_type = GetDeviceType(device_type_name);
 
-       if (_backendType == MEDIA_VISION_ERROR_INVALID_PARAMETER ||
-               _targetDeviceType == MEDIA_VISION_ERROR_INVALID_PARAMETER)
+       if (engine_type == MEDIA_VISION_ERROR_INVALID_PARAMETER || device_type == MEDIA_VISION_ERROR_INVALID_PARAMETER)
                throw InvalidParameter("backend or target device type not found.");
+
+       _config->setBackendType(engine_type);
+       _config->setTargetDeviceType(device_type);
+
+       LOGI("Engine type : %s => %d, Device type : %s => %d", engine_type_name.c_str(), engine_type,
+                device_type_name.c_str(), device_type);
 }
 
 void LandmarkDetection::getNumberOfEngines(unsigned int *number_of_engines)
@@ -154,27 +153,18 @@ void LandmarkDetection::getDeviceType(const char *engine_type, const unsigned in
        *device_type = const_cast<char *>(_valid_devices[device_index].data());
 }
 
-void LandmarkDetection::setUserModel(string model_file, string meta_file, string label_file)
-{
-       _modelFilePath = model_file;
-       _modelMetaFilePath = meta_file;
-       _modelLabelFilePath = label_file;
-}
-
-static bool IsJsonFile(const string &fileName)
-{
-       return (!fileName.substr(fileName.find_last_of(".") + 1).compare("json"));
-}
-
 void LandmarkDetection::loadLabel()
 {
+       if (_config->getLabelFilePath().empty())
+               return;
+
        ifstream readFile;
 
        _labels.clear();
-       readFile.open(_modelLabelFilePath.c_str());
+       readFile.open(_config->getLabelFilePath().c_str());
 
        if (readFile.fail())
-               throw InvalidOperation("Fail to open " + _modelLabelFilePath + " file.");
+               throw InvalidOperation("Fail to open " + _config->getLabelFilePath() + " file.");
 
        string line;
 
@@ -184,95 +174,24 @@ void LandmarkDetection::loadLabel()
        readFile.close();
 }
 
-void LandmarkDetection::parseMetaFile(const string &meta_file_name)
+void LandmarkDetection::configure()
 {
-       int ret = MEDIA_VISION_ERROR_NONE;
-       _config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + meta_file_name);
-
-       if (_backendType == MV_INFERENCE_BACKEND_NONE) {
-               ret = _config->getIntegerAttribute(MV_LANDMARK_DETECTION_BACKEND_TYPE, &_backendType);
-               if (ret != MEDIA_VISION_ERROR_NONE)
-                       throw InvalidOperation("Fail to get backend engine type.");
-       }
-
-       if (_targetDeviceType == MV_INFERENCE_TARGET_DEVICE_NONE) {
-               ret = _config->getIntegerAttribute(MV_LANDMARK_DETECTION_TARGET_DEVICE_TYPE, &_targetDeviceType);
-               if (ret != MEDIA_VISION_ERROR_NONE)
-                       throw InvalidOperation("Fail to get target device type.");
-       }
-
-       ret = _config->getStringAttribute(MV_LANDMARK_DETECTION_MODEL_DEFAULT_PATH, &_modelDefaultPath);
-       if (ret != MEDIA_VISION_ERROR_NONE)
-               throw InvalidOperation("Fail to get model default path");
-
-       if (_modelFilePath.empty()) {
-               ret = _config->getStringAttribute(MV_LANDMARK_DETECTION_MODEL_FILE_PATH, &_modelFilePath);
-               if (ret != MEDIA_VISION_ERROR_NONE)
-                       throw InvalidOperation("Fail to get model file path");
-       }
-
-       _modelFilePath = _modelDefaultPath + _modelFilePath;
-       LOGI("model file path = %s", _modelFilePath.c_str());
-
-       if (_modelMetaFilePath.empty()) {
-               ret = _config->getStringAttribute(MV_LANDMARK_DETECTION_MODEL_META_FILE_PATH, &_modelMetaFilePath);
-               if (ret != MEDIA_VISION_ERROR_NONE)
-                       throw InvalidOperation("Fail to get model meta file path");
-
-               if (_modelMetaFilePath.empty())
-                       throw InvalidOperation("Model meta file doesn't exist.");
-
-               if (!IsJsonFile(_modelMetaFilePath))
-                       throw InvalidOperation("Model meta file should be json");
-       }
-
-       _modelMetaFilePath = _modelDefaultPath + _modelMetaFilePath;
-       LOGI("meta file path = %s", _modelMetaFilePath.c_str());
-
-       _parser->setTaskType(static_cast<int>(_task_type));
-       _parser->load(_modelMetaFilePath);
-
-       if (_modelLabelFilePath.empty()) {
-               ret = _config->getStringAttribute(MV_LANDMARK_DETECTION_LABEL_FILE_NAME, &_modelLabelFilePath);
-               if (ret != MEDIA_VISION_ERROR_NONE)
-                       throw InvalidOperation("Fail to get label file path");
-
-               if (_modelLabelFilePath.empty()) {
-                       LOGW("Label doesn't exist.");
-                       return;
-               }
-       }
-
-       _modelLabelFilePath = _modelDefaultPath + _modelLabelFilePath;
-       LOGI("label file path = %s", _modelLabelFilePath.c_str());
-
-       ret = _config->getDoubleAttribute(MV_LANDMARK_DETECTION_CONFIDENCE_THRESHOLD, &_confidence_threshold);
-       if (ret != MEDIA_VISION_ERROR_NONE)
-               LOGW("threshold value doesn't exist.");
-
-       loadLabel();
-}
-
-void LandmarkDetection::configure(const string &configFile)
-{
-       parseMetaFile(configFile);
-
-       int ret = _inference->bind(_backendType, _targetDeviceType);
+       int ret = _inference->bind(_config->getBackendType(), _config->getTargetDeviceType());
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to bind a backend engine.");
 }
 
 void LandmarkDetection::prepare()
 {
-       int ret = _inference->configureInputMetaInfo(_parser->getInputMetaMap());
+       int ret = _inference->configureInputMetaInfo(_config->getInputMetaMap());
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to configure input tensor info from meta file.");
 
-       ret = _inference->configureOutputMetaInfo(_parser->getOutputMetaMap());
+       ret = _inference->configureOutputMetaInfo(_config->getOutputMetaMap());
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to configure output tensor info from meta file.");
 
-       _inference->configureModelFiles("", _modelFilePath, "");
+       _inference->configureModelFiles("", _config->getModelFilePath(), "");
 
        // Request to load model files to a backend engine.
        ret = _inference->load();
@@ -292,7 +211,7 @@ shared_ptr<MetaInfo> LandmarkDetection::getInputMetaInfo()
        auto tensor_buffer_iter = tensor_info_map.begin();
 
        // Get the meta information corresponding to a given input tensor name.
-       return _parser->getInputMetaMap()[tensor_buffer_iter->first];
+       return _config->getInputMetaMap()[tensor_buffer_iter->first];
 }
 
 template<typename T>
diff --git a/mv_machine_learning/landmark_detection/src/landmark_detection_config.cpp b/mv_machine_learning/landmark_detection/src/landmark_detection_config.cpp
new file mode 100644 (file)
index 0000000..90c9767
--- /dev/null
@@ -0,0 +1,175 @@
+/**
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "machine_learning_exception.h"
+#include "mv_landmark_detection_config.h"
+#include "LandmarkDetectionParser.h"
+#include "landmark_detection_config.h"
+
+using namespace std;
+using namespace MediaVision::Common;
+using namespace mediavision::machine_learning;
+using namespace mediavision::machine_learning::exception;
+
+namespace mediavision
+{
+namespace machine_learning
+{
+LandmarkDetectionConfig::LandmarkDetectionConfig()
+{
+       _parser = make_unique<LandmarkDetectionParser>();
+}
+
+void LandmarkDetectionConfig::setTaskType(LandmarkDetectionTaskType task_type)
+{
+       _task_type = task_type;
+}
+
+void LandmarkDetectionConfig::setBackendType(int backend_type)
+{
+       _backendType = backend_type;
+}
+
+void LandmarkDetectionConfig::setTargetDeviceType(int device_type)
+{
+       _targetDeviceType = device_type;
+}
+
+LandmarkDetectionTaskType LandmarkDetectionConfig::getTaskType() const
+{
+       return _task_type;
+}
+
+const std::string &LandmarkDetectionConfig::getDefaultModelName() const
+{
+       return _defaultModelName;
+}
+
+const std::string &LandmarkDetectionConfig::getModelFilePath() const
+{
+       return _modelFilePath;
+}
+
+const std::string &LandmarkDetectionConfig::getLabelFilePath() const
+{
+       return _modelLabelFilePath;
+}
+
+MetaMap &LandmarkDetectionConfig::getInputMetaMap() const
+{
+       return _parser->getInputMetaMap();
+}
+
+MetaMap &LandmarkDetectionConfig::getOutputMetaMap() const
+{
+       return _parser->getOutputMetaMap();
+}
+
+double LandmarkDetectionConfig::getConfidenceThreshold() const
+{
+       return _confidence_threshold;
+}
+
+int LandmarkDetectionConfig::getBackendType() const
+{
+       return _backendType;
+}
+
+int LandmarkDetectionConfig::getTargetDeviceType() const
+{
+       return _targetDeviceType;
+}
+
+/**
+ * Overrides the stored model, meta and label file paths with user-given names.
+ *
+ * Each non-empty argument is prefixed with _modelDefaultPath and stored in the
+ * corresponding member; an empty argument leaves that member untouched.
+ *
+ * @param model_file  model file name, or an empty string to keep the current path.
+ * @param meta_file   meta file name, or an empty string to keep the current path.
+ * @param label_file  label file name, or an empty string to keep the current path.
+ */
+void LandmarkDetectionConfig::setUserModel(const string &model_file, const string &meta_file, const string &label_file)
+{
+       if (!model_file.empty())
+               _modelFilePath = _modelDefaultPath + model_file;
+       // Overwrite only when a name was actually given; testing for empty() here
+       // would drop the given name and reset the path to the bare default directory.
+       if (!meta_file.empty())
+               _modelMetaFilePath = _modelDefaultPath + meta_file;
+       if (!label_file.empty())
+               _modelLabelFilePath = _modelDefaultPath + label_file;
+}
+
+static bool IsJsonFile(const string &fileName)
+{
+       return (!fileName.substr(fileName.find_last_of(".") + 1).compare("json"));
+}
+
+void LandmarkDetectionConfig::parseConfigFile(const std::string &configFilePath)
+{
+       auto config = make_unique<EngineConfig>(MV_CONFIG_PATH + configFilePath);
+
+       int ret = config->getStringAttribute(MV_LANDMARK_DETECTION_DEFAULT_MODEL_NAME, &_defaultModelName);
+       if (ret != MEDIA_VISION_ERROR_NONE)
+               throw InvalidOperation("Fail to get default model name.");
+
+       if (_backendType == MV_INFERENCE_BACKEND_NONE) {
+               ret = config->getIntegerAttribute(MV_LANDMARK_DETECTION_BACKEND_TYPE, &_backendType);
+               if (ret != MEDIA_VISION_ERROR_NONE)
+                       throw InvalidOperation("Fail to get backend engine type.");
+       }
+
+       if (_targetDeviceType == MV_INFERENCE_TARGET_DEVICE_NONE) {
+               ret = config->getIntegerAttribute(MV_LANDMARK_DETECTION_TARGET_DEVICE_TYPE, &_targetDeviceType);
+               if (ret != MEDIA_VISION_ERROR_NONE)
+                       throw InvalidOperation("Fail to get target device type.");
+       }
+
+       ret = config->getStringAttribute(MV_LANDMARK_DETECTION_MODEL_DEFAULT_PATH, &_modelDefaultPath);
+       if (ret != MEDIA_VISION_ERROR_NONE)
+               throw InvalidOperation("Fail to get model default path");
+
+       ret = config->getStringAttribute(MV_LANDMARK_DETECTION_MODEL_FILE_PATH, &_modelFilePath);
+       if (ret != MEDIA_VISION_ERROR_NONE)
+               throw InvalidOperation("Fail to get model file path");
+
+       _modelFilePath = _modelDefaultPath + _modelFilePath;
+       LOGI("model file path = %s", _modelFilePath.c_str());
+
+       ret = config->getStringAttribute(MV_LANDMARK_DETECTION_MODEL_META_FILE_PATH, &_modelMetaFilePath);
+       if (ret != MEDIA_VISION_ERROR_NONE)
+               throw InvalidOperation("Fail to get model meta file path");
+
+       if (_modelMetaFilePath.empty())
+               throw InvalidOperation("Model meta file doesn't exist.");
+
+       if (!IsJsonFile(_modelMetaFilePath))
+               throw InvalidOperation("Model meta file should be json");
+
+       _modelMetaFilePath = _modelDefaultPath + _modelMetaFilePath;
+       LOGI("meta file path = %s", _modelMetaFilePath.c_str());
+
+       _parser->load(_modelMetaFilePath);
+
+       ret = config->getDoubleAttribute(MV_LANDMARK_DETECTION_CONFIDENCE_THRESHOLD, &_confidence_threshold);
+       if (ret != MEDIA_VISION_ERROR_NONE)
+               LOGW("threshold value doesn't exist.");
+
+       ret = config->getStringAttribute(MV_LANDMARK_DETECTION_LABEL_FILE_NAME, &_modelLabelFilePath);
+       if (ret != MEDIA_VISION_ERROR_NONE)
+               throw InvalidOperation("Fail to get label file path");
+
+       if (_modelLabelFilePath.empty()) {
+               LOGW("Label doesn't exist.");
+               return;
+       }
+
+       _modelLabelFilePath = _modelDefaultPath + _modelLabelFilePath;
+       LOGI("label file path = %s", _modelLabelFilePath.c_str());
+}
+
+}
+}
\ No newline at end of file
index d16fe17..91f3aeb 100644 (file)
@@ -31,7 +31,8 @@ namespace mediavision
 {
 namespace machine_learning
 {
-PldCpm::PldCpm(LandmarkDetectionTaskType task_type) : LandmarkDetection(task_type), _result()
+PldCpm::PldCpm(LandmarkDetectionTaskType task_type, std::shared_ptr<LandmarkDetectionConfig> config)
+               : LandmarkDetection(task_type, config), _result()
 {}
 
 PldCpm::~PldCpm()
@@ -47,7 +48,7 @@ LandmarkDetectionResult &PldCpm::result()
 
        LandmarkDetection::getOutputNames(names);
 
-       auto scoreMetaInfo = _parser->getOutputMetaMap().at(names[0]);
+       auto scoreMetaInfo = _config->getOutputMetaMap().at(names[0]);
        auto decodingLandmark =
                        static_pointer_cast<DecodingLandmark>(scoreMetaInfo->decodingTypeMap[DecodingType::LANDMARK]);
 
@@ -82,7 +83,7 @@ LandmarkDetectionResult &PldCpm::result()
                for (auto y = 0; y < heatMapHeight; ++y) {
                        for (auto x = 0; x < heatMapWidth; ++x) {
                                auto score = score_tensor[y * heatMapWidth * heatMapChannel + x * heatMapChannel + c];
-                               if (score < _confidence_threshold)
+                               if (score < _config->getConfidenceThreshold())
                                        continue;
 
                                if (max_score < score) {
index 2bd4bc3..91c142d 100644 (file)
@@ -29,15 +29,13 @@ namespace machine_learning
 {
 template<typename T, typename V> PoseLandmarkAdapter<T, V>::PoseLandmarkAdapter() : _source()
 {
-       auto config = make_unique<EngineConfig>(MV_CONFIG_PATH + _config_file_name);
+       _config = make_shared<LandmarkDetectionConfig>();
+       _config->parseConfigFile(_config_file_name);
 
-       string defaultModelName;
+       LandmarkDetectionTaskType model_type = convertToTaskType(_config->getDefaultModelName());
 
-       int ret = config->getStringAttribute(MV_LANDMARK_DETECTION_DEFAULT_MODEL_NAME, &defaultModelName);
-       if (ret != MEDIA_VISION_ERROR_NONE)
-               throw InvalidOperation("Fail to get default model name.");
-
-       create(convertToTaskType(defaultModelName));
+       create(model_type);
+       _config->setTaskType(model_type);
 }
 
 template<typename T, typename V> PoseLandmarkAdapter<T, V>::~PoseLandmarkAdapter()
@@ -56,7 +54,7 @@ template<typename T, typename V> void PoseLandmarkAdapter<T, V>::create(Landmark
        }
 
        if (task_type == LandmarkDetectionTaskType::PLD_CPM)
-               _landmark_detection = make_unique<PldCpm>(task_type);
+               _landmark_detection = make_unique<PldCpm>(task_type, _config);
 }
 
 template<typename T, typename V>
@@ -80,24 +78,19 @@ void PoseLandmarkAdapter<T, V>::setModelInfo(const char *model_file, const char
                                                                                         const char *model_name)
 {
        try {
-               create(convertToTaskType(model_name));
+               LandmarkDetectionTaskType model_type = convertToTaskType(model_name);
+
+               _config->setUserModel(model_file, meta_file, label_file);
+               create(model_type);
+               _config->setTaskType(model_type);
        } catch (const BaseException &e) {
                LOGW("A given model name is invalid so default task type will be used.");
        }
 
-       if (model_file)
-               _model_file = model_file;
-       if (meta_file)
-               _meta_file = meta_file;
-       if (label_file)
-               _label_file = label_file;
-
-       if (_model_file.empty() && _meta_file.empty()) {
+       if (!model_file && !meta_file) {
                LOGW("Given model info is invalid so default model info will be used instead.");
                return;
        }
-
-       _landmark_detection->setUserModel(_model_file, _meta_file, _label_file);
 }
 
 template<typename T, typename V>
@@ -108,7 +101,7 @@ void PoseLandmarkAdapter<T, V>::setEngineInfo(const char *engine_type, const cha
 
 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::configure()
 {
-       _landmark_detection->configure(_config_file_name);
+       _landmark_detection->configure();
 }
 
 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::getNumberOfEngines(unsigned int *number_of_engines)