mv_machine_learning: drop parsing dependency from object detection
author Vibhav Aggarwal <v.aggarwal@samsung.com>
Mon, 6 Nov 2023 09:59:32 +0000 (18:59 +0900)
committer Inki Dae <inki.dae@samsung.com>
Tue, 14 Nov 2023 07:39:56 +0000 (16:39 +0900)
[Issue type] : code refactoring

Drop the configuration and meta file parsing dependency from the
ObjectDetection class.

Until now, the concrete class of each task group got the task group
configuration information from its own configuration file, and also
included a MetaParser class object to get the tensor information
corresponding to a given model file.

However, these dependencies led to a code smell, divergent change[1], even
though the concrete class has no real dependency on parsing the configuration
and meta files - it needs only the information obtained after parsing.

As a first refactoring step, this patch extracts the parsing portion from
the ObjectDetection class and introduces it as a new class,
ObjectDetectionConfig.

With this, the adapter classes of the object detection task group parse
the configuration and meta files before creating the ObjectDetection class,
and then create the ObjectDetection class with the needed information.
As a result, we can manage the ObjectDetection class without any
dependency on parsing work.

[1] https://refactoring.guru/smells/divergent-change

Change-Id: I526d58e0012ba4daae8eb6eb7ef99bb3cf8545d6
Signed-off-by: Vibhav Aggarwal <v.aggarwal@samsung.com>
15 files changed:
mv_machine_learning/object_detection/include/face_detection_adapter.h
mv_machine_learning/object_detection/include/iobject_detection.h
mv_machine_learning/object_detection/include/mobilenet_v1_ssd.h
mv_machine_learning/object_detection/include/mobilenet_v2_ssd.h
mv_machine_learning/object_detection/include/object_detection.h
mv_machine_learning/object_detection/include/object_detection_adapter.h
mv_machine_learning/object_detection/include/object_detection_config.h [new file with mode: 0644]
mv_machine_learning/object_detection/include/object_detection_external.h
mv_machine_learning/object_detection/src/face_detection_adapter.cpp
mv_machine_learning/object_detection/src/mobilenet_v1_ssd.cpp
mv_machine_learning/object_detection/src/mobilenet_v2_ssd.cpp
mv_machine_learning/object_detection/src/object_detection.cpp
mv_machine_learning/object_detection/src/object_detection_adapter.cpp
mv_machine_learning/object_detection/src/object_detection_config.cpp [new file with mode: 0644]
mv_machine_learning/object_detection/src/object_detection_external.cpp

index c6f8dc9..c93f441 100644 (file)
@@ -21,6 +21,7 @@
 
 #include "EngineConfig.h"
 #include "itask.h"
+#include "object_detection_config.h"
 #include "mobilenet_v1_ssd.h"
 
 namespace mediavision
@@ -31,11 +32,8 @@ template<typename T, typename V> class FaceDetectionAdapter : public mediavision
 {
 private:
        std::unique_ptr<IObjectDetection> _object_detection;
+       std::shared_ptr<ObjectDetectionConfig> _config;
        T _source;
-       std::string _model_name;
-       std::string _model_file;
-       std::string _meta_file;
-       std::string _label_file;
        const std::string _config_file_name = "face_detection.json";
 
        void create(ObjectDetectionTaskType task_type);
index 72381a3..d05b53a 100644 (file)
@@ -32,13 +32,12 @@ public:
 
        virtual void preDestroy() = 0;
        virtual ObjectDetectionTaskType getTaskType() = 0;
-       virtual void setUserModel(std::string model_file, std::string meta_file, std::string label_file) = 0;
        virtual void setEngineInfo(std::string engine_type, std::string device_type) = 0;
        virtual void getNumberOfEngines(unsigned int *number_of_engines) = 0;
        virtual void getEngineType(unsigned int engine_index, char **engine_type) = 0;
        virtual void getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices) = 0;
        virtual void getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type) = 0;
-       virtual void configure(std::string configFile) = 0;
+       virtual void configure() = 0;
        virtual void prepare() = 0;
        virtual void perform(mv_source_h &mv_src) = 0;
        virtual void performAsync(ObjectDetectionInput &input) = 0;
index ddbe3a9..2a1245b 100644 (file)
@@ -35,7 +35,7 @@ private:
        ObjectDetectionResult _result;
 
 public:
-       MobilenetV1Ssd(ObjectDetectionTaskType task_type);
+       MobilenetV1Ssd(ObjectDetectionTaskType task_type, std::shared_ptr<ObjectDetectionConfig> config);
        ~MobilenetV1Ssd();
 
        ObjectDetectionResult &result() override;
index ac61fb9..e76ac54 100644 (file)
@@ -41,7 +41,7 @@ private:
        Box decodeBoxWithAnchor(const BoxAnchorParam *boxAnchorParam, Box &box, cv::Rect2f &anchor);
 
 public:
-       MobilenetV2Ssd(ObjectDetectionTaskType task_type);
+       MobilenetV2Ssd(ObjectDetectionTaskType task_type, std::shared_ptr<ObjectDetectionConfig> config);
        ~MobilenetV2Ssd();
 
        ObjectDetectionResult &result() override;
index 22c0323..62f0b2a 100644 (file)
@@ -30,7 +30,9 @@
 #include "inference_engine_common_impl.h"
 #include "Inference.h"
 #include "object_detection_type.h"
+#include "MetaParser.h"
 #include "ObjectDetectionParser.h"
+#include "object_detection_config.h"
 #include "machine_learning_preprocess.h"
 #include "iobject_detection.h"
 #include "async_manager.h"
@@ -57,38 +59,30 @@ private:
 
 protected:
        std::unique_ptr<mediavision::inference::Inference> _inference;
-       std::unique_ptr<MediaVision::Common::EngineConfig> _config;
-       std::unique_ptr<MetaParser> _parser;
+       std::shared_ptr<ObjectDetectionConfig> _config;
        std::vector<std::string> _labels;
        std::vector<std::string> _valid_backends;
        std::vector<std::string> _valid_devices;
        Preprocess _preprocess;
-       std::string _modelFilePath;
-       std::string _modelMetaFilePath;
-       std::string _modelDefaultPath;
-       std::string _modelLabelFilePath;
-       int _backendType {};
-       int _targetDeviceType {};
 
        void getOutputNames(std::vector<std::string> &names);
        void getOutputTensor(std::string target_name, std::vector<float> &tensor);
-       void parseMetaFile(std::string meta_file_name);
        template<typename T> void inference(std::vector<std::vector<T> > &inputVectors);
        virtual ObjectDetectionResult &result() = 0;
 
 public:
-       explicit ObjectDetection(ObjectDetectionTaskType task_type);
+       explicit ObjectDetection(ObjectDetectionTaskType task_type, std::shared_ptr<ObjectDetectionConfig> config);
        virtual ~ObjectDetection() = default;
 
        void preDestroy() override;
        ObjectDetectionTaskType getTaskType() override;
-       void setUserModel(std::string model_file, std::string meta_file, std::string label_file) override;
-       void setEngineInfo(std::string engine_type, std::string device_type) override;
+       void setUserModel(std::string model_file, std::string meta_file, std::string label_file);
+       void setEngineInfo(std::string engine_type_name, std::string device_type_name) override;
        void getNumberOfEngines(unsigned int *number_of_engines) override;
        void getEngineType(unsigned int engine_index, char **engine_type) override;
        void getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices) override;
        void getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type) override;
-       void configure(std::string configFile) override;
+       void configure() override;
        void prepare() override;
        void perform(mv_source_h &mv_src) override;
        void performAsync(ObjectDetectionInput &input) override;
index 41be841..c42d6b4 100644 (file)
@@ -21,6 +21,7 @@
 
 #include "EngineConfig.h"
 #include "itask.h"
+#include "object_detection_config.h"
 #include "mobilenet_v1_ssd.h"
 #include "mobilenet_v2_ssd.h"
 
@@ -32,12 +33,9 @@ template<typename T, typename V> class ObjectDetectionAdapter : public mediavisi
 {
 private:
        std::unique_ptr<IObjectDetection> _object_detection;
+       std::shared_ptr<ObjectDetectionConfig> _config;
        T _source;
-       std::string _model_name;
-       std::string _model_file;
-       std::string _meta_file;
-       std::string _label_file;
-       const std::string _meta_file_name = "object_detection.json";
+       const std::string _config_file_name = "object_detection.json";
 
        void create(ObjectDetectionTaskType task_type);
        ObjectDetectionTaskType convertToTaskType(std::string model_name);
diff --git a/mv_machine_learning/object_detection/include/object_detection_config.h b/mv_machine_learning/object_detection/include/object_detection_config.h
new file mode 100644 (file)
index 0000000..4f621b8
--- /dev/null
@@ -0,0 +1,65 @@
+/**
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OBJECT_DETECTION_CONFIG_H__
+#define __OBJECT_DETECTION_CONFIG_H__
+
+#include <mv_common.h>
+#include "mv_private.h"
+#include "EngineConfig.h"
+
+#include "MetaParser.h"
+#include "object_detection_type.h"
+
+namespace mediavision
+{
+namespace machine_learning
+{
+class ObjectDetectionConfig
+{
+private:
+       std::unique_ptr<MetaParser> _parser;
+       std::string _defaultModelName;
+       std::string _modelFilePath;
+       std::string _modelMetaFilePath;
+       std::string _modelDefaultPath;
+       std::string _modelLabelFilePath;
+       int _backendType {};
+       int _targetDeviceType {};
+
+public:
+       ObjectDetectionConfig();
+       virtual ~ObjectDetectionConfig() = default;
+
+       void setUserModel(const std::string &model_file, const std::string &meta_file, const std::string &label_file);
+       void parseConfigFile(const std::string &configFilePath);
+       void parseMetaFile();
+       void setBackendType(int backend_type);
+       void setTargetDeviceType(int device_type);
+       const std::string &getDefaultModelName() const;
+       const std::string &getModelFilePath() const;
+       const std::string &getLabelFilePath() const;
+       MetaMap &getInputMetaMap() const;
+       MetaMap &getOutputMetaMap() const;
+       int getBackendType() const;
+       int getTargetDeviceType() const;
+       void loadMetaFile(ObjectDetectionTaskType task_type);
+};
+
+} // machine_learning
+} // mediavision
+
+#endif
\ No newline at end of file
index 72a0b5a..ea7553b 100644 (file)
@@ -24,6 +24,7 @@
 #include <opencv2/imgproc.hpp>
 
 #include "object_detection_type.h"
+#include "object_detection_config.h"
 #include "iobject_detection.h"
 
 namespace mediavision
@@ -44,13 +45,12 @@ public:
 
        void preDestroy() override;
        ObjectDetectionTaskType getTaskType() override;
-       void setUserModel(std::string model_file, std::string meta_file, std::string label_file) override;
        void setEngineInfo(std::string engine_type, std::string device_type) override;
        void getNumberOfEngines(unsigned int *number_of_engines) override;
        void getEngineType(unsigned int engine_index, char **engine_type) override;
        void getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices) override;
        void getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type) override;
-       void configure(std::string configFile) override;
+       void configure() override;
        void prepare() override;
        void perform(mv_source_h &mv_src) override;
        void performAsync(ObjectDetectionInput &input) override;
index fd4d98c..b34c678 100644 (file)
@@ -29,14 +29,11 @@ namespace machine_learning
 {
 template<typename T, typename V> FaceDetectionAdapter<T, V>::FaceDetectionAdapter() : _source()
 {
-       auto config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + _config_file_name);
-       string defaultModelName;
+       _config = make_shared<ObjectDetectionConfig>();
+       _config->parseConfigFile(_config_file_name);
 
-       int ret = config->getStringAttribute(MV_OBJECT_DETECTION_DEFAULT_MODEL_NAME, &defaultModelName);
-       if (ret != MEDIA_VISION_ERROR_NONE)
-               throw InvalidOperation("Fail to get default model name.");
-
-       create(convertToTaskType(defaultModelName.c_str()));
+       ObjectDetectionTaskType model_type = convertToTaskType(_config->getDefaultModelName());
+       create(model_type);
 }
 
 template<typename T, typename V> FaceDetectionAdapter<T, V>::~FaceDetectionAdapter()
@@ -52,12 +49,13 @@ template<typename T, typename V> void FaceDetectionAdapter<T, V>::create(ObjectD
                // If default task type is same as a given one then skip.
                if (_object_detection->getTaskType() == task_type)
                        return;
-
-               _object_detection.reset();
        }
 
+       // if model name is changed by user then reallocate the parser and reload the meta file corresponding to the model name.
+       _config->loadMetaFile(task_type);
+
        if (task_type == ObjectDetectionTaskType::FD_MOBILENET_V1_SSD)
-               _object_detection = make_unique<MobilenetV1Ssd>(task_type);
+               _object_detection = make_unique<MobilenetV1Ssd>(task_type, _config);
        // TODO.
 }
 
@@ -83,24 +81,18 @@ void FaceDetectionAdapter<T, V>::setModelInfo(const char *model_file, const char
                                                                                          const char *model_name)
 {
        try {
-               create(convertToTaskType(model_name));
+               _config->setUserModel(model_file, meta_file, label_file);
+
+               ObjectDetectionTaskType model_type = convertToTaskType(model_name);
+               create(model_type);
        } catch (const BaseException &e) {
                LOGW("A given model name is invalid so default task type will be used.");
        }
 
-       if (model_file)
-               _model_file = model_file;
-       if (meta_file)
-               _meta_file = meta_file;
-       if (label_file)
-               _label_file = label_file;
-
-       if (_model_file.empty() && _meta_file.empty() && _label_file.empty()) {
+       if (!model_file && !meta_file) {
                LOGW("Given model info is invalid so default model info will be used instead.");
                return;
        }
-
-       _object_detection->setUserModel(_model_file, _meta_file, _label_file);
 }
 
 template<typename T, typename V>
@@ -111,7 +103,7 @@ void FaceDetectionAdapter<T, V>::setEngineInfo(const char *engine_type, const ch
 
 template<typename T, typename V> void FaceDetectionAdapter<T, V>::configure()
 {
-       _object_detection->configure(_config_file_name);
+       _object_detection->configure();
 }
 
 template<typename T, typename V> void FaceDetectionAdapter<T, V>::getNumberOfEngines(unsigned int *number_of_engines)
index 51e6e0a..dd50548 100644 (file)
@@ -31,7 +31,8 @@ namespace mediavision
 {
 namespace machine_learning
 {
-MobilenetV1Ssd::MobilenetV1Ssd(ObjectDetectionTaskType task_type) : ObjectDetection(task_type), _result()
+MobilenetV1Ssd::MobilenetV1Ssd(ObjectDetectionTaskType task_type, std::shared_ptr<ObjectDetectionConfig> config)
+               : ObjectDetection(task_type, config), _result()
 {}
 
 MobilenetV1Ssd::~MobilenetV1Ssd()
@@ -60,7 +61,7 @@ ObjectDetectionResult &MobilenetV1Ssd::result()
        vector<float> score_tensor;
        map<float, unsigned int, std::greater<float> > sorted_score;
 
-       auto scoreMetaInfo = _parser->getOutputMetaMap().at(names[2]);
+       auto scoreMetaInfo = _config->getOutputMetaMap().at(names[2]);
        auto decodingScore = static_pointer_cast<DecodingScore>(scoreMetaInfo->decodingTypeMap[DecodingType::SCORE]);
 
        // TFLite_Detection_PostProcess:2
@@ -72,7 +73,7 @@ ObjectDetectionResult &MobilenetV1Ssd::result()
                sorted_score[score_tensor[idx]] = idx;
        }
 
-       auto boxMetaInfo = _parser->getOutputMetaMap().at(names[0]);
+       auto boxMetaInfo = _config->getOutputMetaMap().at(names[0]);
        auto decodingBox = static_pointer_cast<DecodingBox>(boxMetaInfo->decodingTypeMap[DecodingType::BOX]);
        vector<float> box_tensor;
 
index ef1e312..9c5373f 100644 (file)
@@ -32,7 +32,8 @@ namespace mediavision
 {
 namespace machine_learning
 {
-MobilenetV2Ssd::MobilenetV2Ssd(ObjectDetectionTaskType task_type) : ObjectDetection(task_type), _result()
+MobilenetV2Ssd::MobilenetV2Ssd(ObjectDetectionTaskType task_type, std::shared_ptr<ObjectDetectionConfig> config)
+               : ObjectDetection(task_type, config), _result()
 {}
 
 MobilenetV2Ssd::~MobilenetV2Ssd()
@@ -184,10 +185,10 @@ ObjectDetectionResult &MobilenetV2Ssd::result()
        // raw_outputs/class_predictions
        ObjectDetection::getOutputTensor(names[1], score_tensor);
 
-       auto scoreMetaInfo = _parser->getOutputMetaMap().at(names[1]);
+       auto scoreMetaInfo = _config->getOutputMetaMap().at(names[1]);
        auto decodingScore = static_pointer_cast<DecodingScore>(scoreMetaInfo->decodingTypeMap[DecodingType::SCORE]);
 
-       auto boxMetaInfo = _parser->getOutputMetaMap().at(names[0]);
+       auto boxMetaInfo = _config->getOutputMetaMap().at(names[0]);
        auto decodingBox = static_pointer_cast<DecodingBox>(boxMetaInfo->decodingTypeMap[DecodingType::BOX]);
        auto anchorParam = static_pointer_cast<BoxAnchorParam>(decodingBox->decodingInfoMap[BoxDecodingType::SSD_ANCHOR]);
        unsigned int number_of_objects = scoreMetaInfo->dims[2]; // Shape is 1 x 2034 x 91
index e2336dd..3c9622e 100644 (file)
@@ -34,13 +34,11 @@ namespace mediavision
 {
 namespace machine_learning
 {
-ObjectDetection::ObjectDetection(ObjectDetectionTaskType task_type)
-               : _task_type(task_type)
-               , _backendType(MV_INFERENCE_BACKEND_NONE)
-               , _targetDeviceType(MV_INFERENCE_TARGET_DEVICE_NONE)
+ObjectDetection::ObjectDetection(ObjectDetectionTaskType task_type, shared_ptr<ObjectDetectionConfig> config)
+               : _task_type(task_type), _config(config)
 {
        _inference = make_unique<Inference>();
-       _parser = make_unique<ObjectDetectionParser>();
+       loadLabel();
 }
 
 void ObjectDetection::preDestroy()
@@ -76,23 +74,25 @@ void ObjectDetection::getDeviceList(const char *engine_type)
        _valid_devices.push_back("gpu");
 }
 
-void ObjectDetection::setEngineInfo(std::string engine_type, std::string device_type)
+void ObjectDetection::setEngineInfo(std::string engine_type_name, std::string device_type_name)
 {
-       if (engine_type.empty() || device_type.empty())
+       if (engine_type_name.empty() || device_type_name.empty())
                throw InvalidParameter("Invalid engine info.");
 
-       transform(engine_type.begin(), engine_type.end(), engine_type.begin(), ::toupper);
-       transform(device_type.begin(), device_type.end(), device_type.begin(), ::toupper);
-
-       _backendType = GetBackendType(engine_type);
-       _targetDeviceType = GetDeviceType(device_type);
+       transform(engine_type_name.begin(), engine_type_name.end(), engine_type_name.begin(), ::toupper);
+       transform(device_type_name.begin(), device_type_name.end(), device_type_name.begin(), ::toupper);
 
-       LOGI("Engine type : %s => %d, Device type : %s => %d", engine_type.c_str(), GetBackendType(engine_type),
-                device_type.c_str(), GetDeviceType(device_type));
+       int engine_type = GetBackendType(engine_type_name);
+       int device_type = GetDeviceType(device_type_name);
 
-       if (_backendType == MEDIA_VISION_ERROR_INVALID_PARAMETER ||
-               _targetDeviceType == MEDIA_VISION_ERROR_INVALID_PARAMETER)
+       if (engine_type == MEDIA_VISION_ERROR_INVALID_PARAMETER || device_type == MEDIA_VISION_ERROR_INVALID_PARAMETER)
                throw InvalidParameter("backend or target device type not found.");
+
+       _config->setBackendType(engine_type);
+       _config->setTargetDeviceType(device_type);
+
+       LOGI("Engine type : %s => %d, Device type : %s => %d", engine_type_name.c_str(), engine_type,
+                device_type_name.c_str(), device_type);
 }
 
 void ObjectDetection::getNumberOfEngines(unsigned int *number_of_engines)
@@ -153,27 +153,18 @@ void ObjectDetection::getDeviceType(const char *engine_type, const unsigned int
        *device_type = const_cast<char *>(_valid_devices[device_index].data());
 }
 
-void ObjectDetection::setUserModel(string model_file, string meta_file, string label_file)
-{
-       _modelFilePath = model_file;
-       _modelMetaFilePath = meta_file;
-       _modelLabelFilePath = label_file;
-}
-
-static bool IsJsonFile(const string &fileName)
-{
-       return (!fileName.substr(fileName.find_last_of(".") + 1).compare("json"));
-}
-
 void ObjectDetection::loadLabel()
 {
+       if (_config->getLabelFilePath().empty())
+               return;
+
        ifstream readFile;
 
        _labels.clear();
-       readFile.open(_modelLabelFilePath.c_str());
+       readFile.open(_config->getLabelFilePath().c_str());
 
        if (readFile.fail())
-               throw InvalidOperation("Fail to open " + _modelLabelFilePath + " file.");
+               throw InvalidOperation("Fail to open " + _config->getLabelFilePath() + " file.");
 
        string line;
 
@@ -183,89 +174,24 @@ void ObjectDetection::loadLabel()
        readFile.close();
 }
 
-void ObjectDetection::parseMetaFile(string meta_file_name)
+void ObjectDetection::configure()
 {
-       int ret = MEDIA_VISION_ERROR_NONE;
-       _config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + meta_file_name);
-
-       if (_backendType == MV_INFERENCE_BACKEND_NONE) {
-               ret = _config->getIntegerAttribute(MV_OBJECT_DETECTION_BACKEND_TYPE, &_backendType);
-               if (ret != MEDIA_VISION_ERROR_NONE)
-                       throw InvalidOperation("Fail to get backend engine type.");
-       }
-
-       if (_targetDeviceType == MV_INFERENCE_TARGET_DEVICE_NONE) {
-               ret = _config->getIntegerAttribute(MV_OBJECT_DETECTION_TARGET_DEVICE_TYPE, &_targetDeviceType);
-               if (ret != MEDIA_VISION_ERROR_NONE)
-                       throw InvalidOperation("Fail to get target device type.");
-       }
-
-       ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_DEFAULT_PATH, &_modelDefaultPath);
-       if (ret != MEDIA_VISION_ERROR_NONE)
-               throw InvalidOperation("Fail to get model default path");
-
-       if (_modelFilePath.empty()) {
-               ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_FILE_PATH, &_modelFilePath);
-               if (ret != MEDIA_VISION_ERROR_NONE)
-                       throw InvalidOperation("Fail to get model file path");
-       }
-
-       _modelFilePath = _modelDefaultPath + _modelFilePath;
-       LOGI("model file path = %s", _modelFilePath.c_str());
-
-       if (_modelMetaFilePath.empty()) {
-               ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_META_FILE_PATH, &_modelMetaFilePath);
-               if (ret != MEDIA_VISION_ERROR_NONE)
-                       throw InvalidOperation("Fail to get model meta file path");
-
-               if (_modelMetaFilePath.empty())
-                       throw InvalidOperation("Model meta file doesn't exist.");
-
-               if (!IsJsonFile(_modelMetaFilePath))
-                       throw InvalidOperation("Model meta file should be json");
-       }
-
-       _modelMetaFilePath = _modelDefaultPath + _modelMetaFilePath;
-       LOGI("meta file path = %s", _modelMetaFilePath.c_str());
-
-       _parser->setTaskType(static_cast<int>(_task_type));
-       _parser->load(_modelMetaFilePath);
-
-       if (_modelLabelFilePath.empty()) {
-               ret = _config->getStringAttribute(MV_OBJECT_DETECTION_LABEL_FILE_NAME, &_modelLabelFilePath);
-               if (ret != MEDIA_VISION_ERROR_NONE)
-                       throw InvalidOperation("Fail to get label file path");
-
-               if (_modelLabelFilePath.empty())
-                       throw InvalidOperation("Model label file doesn't exist.");
-       }
-
-       _modelLabelFilePath = _modelDefaultPath + _modelLabelFilePath;
-       LOGI("label file path = %s", _modelLabelFilePath.c_str());
-
-       loadLabel();
-}
-
-void ObjectDetection::configure(string configFile)
-{
-       parseMetaFile(configFile);
-
-       int ret = _inference->bind(_backendType, _targetDeviceType);
+       int ret = _inference->bind(_config->getBackendType(), _config->getTargetDeviceType());
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to bind a backend engine.");
 }
 
 void ObjectDetection::prepare()
 {
-       int ret = _inference->configureInputMetaInfo(_parser->getInputMetaMap());
+       int ret = _inference->configureInputMetaInfo(_config->getInputMetaMap());
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to configure input tensor info from meta file.");
 
-       ret = _inference->configureOutputMetaInfo(_parser->getOutputMetaMap());
+       ret = _inference->configureOutputMetaInfo(_config->getOutputMetaMap());
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to configure output tensor info from meta file.");
 
-       _inference->configureModelFiles("", _modelFilePath, "");
+       _inference->configureModelFiles("", _config->getModelFilePath(), "");
 
        // Request to load model files to a backend engine.
        ret = _inference->load();
@@ -285,7 +211,7 @@ shared_ptr<MetaInfo> ObjectDetection::getInputMetaInfo()
        auto tensor_buffer_iter = tensor_info_map.begin();
 
        // Get the meta information corresponding to a given input tensor name.
-       return _parser->getInputMetaMap()[tensor_buffer_iter->first];
+       return _config->getInputMetaMap()[tensor_buffer_iter->first];
 }
 
 template<typename T>
index 16c4d9f..e2f74d6 100644 (file)
@@ -30,15 +30,11 @@ namespace machine_learning
 {
 template<typename T, typename V> ObjectDetectionAdapter<T, V>::ObjectDetectionAdapter() : _source()
 {
-       auto config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + _meta_file_name);
+       _config = make_shared<ObjectDetectionConfig>();
+       _config->parseConfigFile(_config_file_name);
 
-       string defaultModelName;
-
-       int ret = config->getStringAttribute(MV_OBJECT_DETECTION_DEFAULT_MODEL_NAME, &defaultModelName);
-       if (ret != MEDIA_VISION_ERROR_NONE)
-               throw InvalidOperation("Fail to get default model name.");
-
-       create(convertToTaskType(defaultModelName.c_str()));
+       ObjectDetectionTaskType model_type = convertToTaskType(_config->getDefaultModelName());
+       create(model_type);
 }
 
 template<typename T, typename V> ObjectDetectionAdapter<T, V>::~ObjectDetectionAdapter()
@@ -54,14 +50,15 @@ template<typename T, typename V> void ObjectDetectionAdapter<T, V>::create(Objec
                // If default task type is same as a given one then skip.
                if (_object_detection->getTaskType() == task_type)
                        return;
-
-               _object_detection.reset();
        }
 
+       // if model name is changed by user then reallocate the parser and reload the meta file corresponding to the model name.
+       _config->loadMetaFile(task_type);
+
        if (task_type == ObjectDetectionTaskType::MOBILENET_V1_SSD)
-               _object_detection = make_unique<MobilenetV1Ssd>(task_type);
+               _object_detection = make_unique<MobilenetV1Ssd>(task_type, _config);
        else if (task_type == ObjectDetectionTaskType::MOBILENET_V2_SSD)
-               _object_detection = make_unique<MobilenetV2Ssd>(task_type);
+               _object_detection = make_unique<MobilenetV2Ssd>(task_type, _config);
        else if (task_type == ObjectDetectionTaskType::OD_PLUGIN || task_type == ObjectDetectionTaskType::FD_PLUGIN)
                _object_detection = make_unique<ObjectDetectionExternal>(task_type);
        // TODO.
@@ -93,24 +90,18 @@ void ObjectDetectionAdapter<T, V>::setModelInfo(const char *model_file, const ch
                                                                                                const char *model_name)
 {
        try {
-               create(convertToTaskType(string(model_name)));
+               _config->setUserModel(model_file, meta_file, label_file);
+
+               ObjectDetectionTaskType model_type = convertToTaskType(model_name);
+               create(model_type);
        } catch (const BaseException &e) {
                LOGW("A given model name is invalid so default task type will be used.");
        }
 
-       if (model_file)
-               _model_file = model_file;
-       if (meta_file)
-               _meta_file = meta_file;
-       if (label_file)
-               _label_file = label_file;
-
-       if (_model_file.empty() && _meta_file.empty() && _label_file.empty()) {
+       if (!model_file && !meta_file) {
                LOGW("Given model info is invalid so default model info will be used instead.");
                return;
        }
-
-       _object_detection->setUserModel(_model_file, _meta_file, _label_file);
 }
 
 template<typename T, typename V>
@@ -121,7 +112,7 @@ void ObjectDetectionAdapter<T, V>::setEngineInfo(const char *engine_type, const
 
 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::configure()
 {
-       _object_detection->configure(_meta_file_name);
+       _object_detection->configure();
 }
 
 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::getNumberOfEngines(unsigned int *number_of_engines)
diff --git a/mv_machine_learning/object_detection/src/object_detection_config.cpp b/mv_machine_learning/object_detection/src/object_detection_config.cpp
new file mode 100644 (file)
index 0000000..073cfef
--- /dev/null
@@ -0,0 +1,159 @@
+/**
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "machine_learning_exception.h"
+#include "mv_object_detection_config.h"
+#include "ObjectDetectionParser.h"
+#include "object_detection_config.h"
+
+using namespace std;
+using namespace MediaVision::Common;
+using namespace mediavision::machine_learning;
+using namespace mediavision::machine_learning::exception;
+
+namespace mediavision
+{
+namespace machine_learning
+{
+ObjectDetectionConfig::ObjectDetectionConfig() // Empty body: no member is assigned here.
+{} // NOTE(review): parseConfigFile() compares _backendType/_targetDeviceType with the *_NONE constants, so they presumably get initial values in the header - confirm.
+
+void ObjectDetectionConfig::setBackendType(int backend_type) // Stores backend_type in _backendType.
+{
+       _backendType = backend_type; // parseConfigFile() reads the backend from the config file only while this equals MV_INFERENCE_BACKEND_NONE.
+}
+
+void ObjectDetectionConfig::setTargetDeviceType(int device_type) // Stores device_type in _targetDeviceType.
+{
+       _targetDeviceType = device_type; // parseConfigFile() reads the device from the config file only while this equals MV_INFERENCE_TARGET_DEVICE_NONE.
+}
+
+const std::string &ObjectDetectionConfig::getDefaultModelName() const // Returns the default model name by reference.
+{
+       return _defaultModelName; // Filled by parseConfigFile() from MV_OBJECT_DETECTION_DEFAULT_MODEL_NAME.
+}
+
+const std::string &ObjectDetectionConfig::getModelFilePath() const // Returns the model file path by reference.
+{
+       return _modelFilePath; // parseConfigFile() and setUserModel() store it prefixed with _modelDefaultPath.
+}
+
+const std::string &ObjectDetectionConfig::getLabelFilePath() const // Returns the label file path by reference.
+{
+       return _modelLabelFilePath; // Left empty by parseConfigFile() when the config file names no label file.
+}
+
+MetaMap &ObjectDetectionConfig::getInputMetaMap() const // Forwards to the parser's input meta map.
+{
+       return _parser->getInputMetaMap(); // NOTE(review): _parser is only assigned in loadMetaFile(); presumably that must be called first - confirm.
+}
+
+MetaMap &ObjectDetectionConfig::getOutputMetaMap() const // Forwards to the parser's output meta map.
+{
+       return _parser->getOutputMetaMap(); // NOTE(review): _parser is only assigned in loadMetaFile(); presumably that must be called first - confirm.
+}
+
+int ObjectDetectionConfig::getBackendType() const // Returns the stored backend type.
+{
+       return _backendType; // Written by setBackendType() or, from the config file, by parseConfigFile().
+}
+
+int ObjectDetectionConfig::getTargetDeviceType() const // Returns the stored target device type.
+{
+       return _targetDeviceType; // Written by setTargetDeviceType() or, from the config file, by parseConfigFile().
+}
+
+/**
+ * Overrides the stored model, meta and label file paths with user-supplied names.
+ *
+ * Every non-empty name is appended to _modelDefaultPath and replaces the matching
+ * stored path; an empty name leaves that path as it is.
+ */
+void ObjectDetectionConfig::setUserModel(const string &model_file, const string &meta_file, const string &label_file)
+{
+       const auto overridePath = [this](string &storedPath, const string &fileName) {
+               if (fileName.empty())
+                       return;
+
+               storedPath = _modelDefaultPath + fileName;
+       };
+
+       overridePath(_modelFilePath, model_file);
+       overridePath(_modelMetaFilePath, meta_file);
+       overridePath(_modelLabelFilePath, label_file);
+}
+
+/**
+ * Tells whether fileName carries a "json" extension.
+ *
+ * Only the text after the last '.' counts and the comparison is case-sensitive.
+ * A name without any '.' has no extension and is rejected. Without that check
+ * find_last_of() yields npos, npos + 1 wraps around to 0 and the whole name
+ * would be compared, so a file literally called "json" would be accepted.
+ */
+static bool IsJsonFile(const string &fileName)
+{
+       const auto dotPos = fileName.find_last_of('.');
+       if (dotPos == string::npos)
+               return false;
+
+       return fileName.compare(dotPos + 1, string::npos, "json") == 0;
+}
+
+/**
+ * Reads the object detection settings from the given configuration file.
+ *
+ * Loads the default model name, the model default path and the model, meta and
+ * label file names; each non-empty file name is stored prefixed with the model
+ * default path. Backend and target device types are read only while they still
+ * hold MV_INFERENCE_BACKEND_NONE / MV_INFERENCE_TARGET_DEVICE_NONE.
+ *
+ * @param configFilePath configuration file name, appended to MV_CONFIG_PATH.
+ * @throws InvalidOperation if an attribute cannot be read, or if the meta file
+ *         name is empty or is rejected by IsJsonFile().
+ */
+void ObjectDetectionConfig::parseConfigFile(const std::string &configFilePath)
+{
+       auto settings = make_unique<EngineConfig>(MV_CONFIG_PATH + configFilePath);
+
+       // Turns a failed attribute lookup into the exception used throughout this function.
+       const auto ensureRead = [](int status, const char *failure) {
+               if (status != MEDIA_VISION_ERROR_NONE)
+                       throw InvalidOperation(failure);
+       };
+
+       ensureRead(settings->getStringAttribute(MV_OBJECT_DETECTION_DEFAULT_MODEL_NAME, &_defaultModelName),
+                  "Fail to get default model name.");
+
+       // A backend or target device that has already been chosen is not overwritten by the file.
+       if (_backendType == MV_INFERENCE_BACKEND_NONE) {
+               ensureRead(settings->getIntegerAttribute(MV_OBJECT_DETECTION_BACKEND_TYPE, &_backendType),
+                          "Fail to get backend engine type.");
+       }
+
+       if (_targetDeviceType == MV_INFERENCE_TARGET_DEVICE_NONE) {
+               ensureRead(settings->getIntegerAttribute(MV_OBJECT_DETECTION_TARGET_DEVICE_TYPE, &_targetDeviceType),
+                          "Fail to get target device type.");
+       }
+
+       ensureRead(settings->getStringAttribute(MV_OBJECT_DETECTION_MODEL_DEFAULT_PATH, &_modelDefaultPath),
+                  "Fail to get model default path");
+
+       ensureRead(settings->getStringAttribute(MV_OBJECT_DETECTION_MODEL_FILE_PATH, &_modelFilePath),
+                  "Fail to get model file path");
+       _modelFilePath.insert(0, _modelDefaultPath);
+       LOGI("model file path = %s", _modelFilePath.c_str());
+
+       ensureRead(settings->getStringAttribute(MV_OBJECT_DETECTION_MODEL_META_FILE_PATH, &_modelMetaFilePath),
+                  "Fail to get model meta file path");
+       if (_modelMetaFilePath.empty())
+               throw InvalidOperation("Model meta file doesn't exist.");
+       if (!IsJsonFile(_modelMetaFilePath))
+               throw InvalidOperation("Model meta file should be json");
+       _modelMetaFilePath.insert(0, _modelDefaultPath);
+       LOGI("meta file path = %s", _modelMetaFilePath.c_str());
+
+       ensureRead(settings->getStringAttribute(MV_OBJECT_DETECTION_LABEL_FILE_NAME, &_modelLabelFilePath),
+                  "Fail to get label file path");
+
+       // The label file is optional: an empty name is only reported and left empty.
+       if (!_modelLabelFilePath.empty()) {
+               _modelLabelFilePath.insert(0, _modelDefaultPath);
+               LOGI("label file path = %s", _modelLabelFilePath.c_str());
+       } else {
+               LOGW("Label doesn't exist.");
+       }
+}
+
+/**
+ * Creates a new ObjectDetectionParser for task_type and loads the meta file
+ * stored in _modelMetaFilePath. getInputMetaMap() and getOutputMetaMap() read
+ * through the parser installed here.
+ */
+void ObjectDetectionConfig::loadMetaFile(ObjectDetectionTaskType task_type)
+{
+       const int taskTypeId = static_cast<int>(task_type);
+
+       // Any previously held parser is replaced before the new one is configured and loaded.
+       _parser = make_unique<ObjectDetectionParser>();
+       _parser->setTaskType(taskTypeId);
+       _parser->load(_modelMetaFilePath);
+}
+
+}
+}
index 08ef647..8084195 100644 (file)
@@ -82,11 +82,6 @@ ObjectDetectionTaskType ObjectDetectionExternal::getTaskType()
        return _object_detection_plugin->getTaskType();
 }
 
-void ObjectDetectionExternal::setUserModel(std::string model_file, std::string meta_file, std::string label_file)
-{
-       _object_detection_plugin->setUserModel(model_file, meta_file, label_file);
-}
-
 void ObjectDetectionExternal::setEngineInfo(std::string engine_type, std::string device_type)
 {
        _object_detection_plugin->setEngineInfo(engine_type, device_type);
@@ -113,9 +108,9 @@ void ObjectDetectionExternal::getDeviceType(const char *engine_type, const unsig
        _object_detection_plugin->getDeviceType(engine_type, device_index, device_type);
 }
 
-void ObjectDetectionExternal::configure(std::string configFile)
+void ObjectDetectionExternal::configure()
 {
-       _object_detection_plugin->configure(configFile);
+       _object_detection_plugin->configure();
 }
 
 void ObjectDetectionExternal::prepare()