mv_machine_learning: refactor ImageSegmentation task group
authorVibhav Aggarwal <v.aggarwal@samsung.com>
Wed, 15 Nov 2023 05:25:38 +0000 (14:25 +0900)
committerKwanghoon Son <k.son@samsung.com>
Wed, 6 Dec 2023 01:36:46 +0000 (10:36 +0900)
[Issue type] code refactoring

Lift parsing dependency from ImageSegmentation class to
ImageSegmentationAdapter and convert ImageSegmentation
into a template class.

Change-Id: I349ac7ba04f7d0193765e8cb72b6ca1a3fb7830d
Signed-off-by: Vibhav Aggarwal <v.aggarwal@samsung.com>
mv_machine_learning/image_segmentation/include/iimage_segmentation.h
mv_machine_learning/image_segmentation/include/image_segmentation.h
mv_machine_learning/image_segmentation/include/image_segmentation_external.h
mv_machine_learning/image_segmentation/include/selfie_segmentation_adapter.h
mv_machine_learning/image_segmentation/src/image_segmentation.cpp
mv_machine_learning/image_segmentation/src/image_segmentation_external.cpp
mv_machine_learning/image_segmentation/src/selfie_segmentation_adapter.cpp

index 1acf9df..bfee8c0 100644 (file)
@@ -31,14 +31,13 @@ public:
        virtual ~IImageSegmentation() {};
 
        virtual void preDestroy() = 0;
-       virtual ImageSegmentationTaskType getTaskType() = 0;
        virtual void setUserModel(std::string model_file, std::string meta_file, std::string label_file) = 0;
        virtual void setEngineInfo(std::string engine_type, std::string device_type) = 0;
        virtual void getNumberOfEngines(unsigned int *number_of_engines) = 0;
        virtual void getEngineType(unsigned int engine_index, char **engine_type) = 0;
        virtual void getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices) = 0;
        virtual void getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type) = 0;
-       virtual void configure(std::string configFile) = 0;
+       virtual void configure() = 0;
        virtual void prepare() = 0;
        virtual void perform(mv_source_h &mv_src) = 0;
        virtual void performAsync(ImageSegmentationInput &input) = 0;
index ac5d14f..b56df31 100644 (file)
@@ -30,6 +30,8 @@
 #include "inference_engine_common_impl.h"
 #include "Inference.h"
 #include "image_segmentation_type.h"
+#include "MetaParser.h"
+#include "machine_learning_config.h"
 #include "image_segmentation_parser.h"
 #include "machine_learning_preprocess.h"
 #include "iimage_segmentation.h"
@@ -39,56 +41,43 @@ namespace mediavision
 {
 namespace machine_learning
 {
-class ImageSegmentation : public IImageSegmentation
+template<typename T> class ImageSegmentation : public IImageSegmentation
 {
 private:
-       ImageSegmentationTaskType _task_type { ImageSegmentationTaskType::IMAGE_SEGMENTATION_TASK_NONE };
        std::unique_ptr<AsyncManager<ImageSegmentationResult> > _async_manager;
        ImageSegmentationResult _current_result {};
 
        void loadLabel();
        void getEngineList();
        void getDeviceList(const char *engine_type);
-       template<typename T>
        void preprocess(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo, std::vector<T> &inputVector);
        std::shared_ptr<MetaInfo> getInputMetaInfo();
-       template<typename T> void perform(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo);
-       template<typename T> void performAsync(ImageSegmentationInput &input, std::shared_ptr<MetaInfo> metaInfo);
 
 protected:
        std::unique_ptr<mediavision::inference::Inference> _inference;
-       std::unique_ptr<MediaVision::Common::EngineConfig> _config;
-       std::unique_ptr<MetaParser> _parser;
+       std::shared_ptr<MachineLearningConfig> _config;
        std::vector<std::string> _labels;
        std::vector<std::string> _valid_backends;
        std::vector<std::string> _valid_devices;
        Preprocess _preprocess;
-       std::string _modelFilePath;
-       std::string _modelMetaFilePath;
-       std::string _modelDefaultPath;
-       std::string _modelLabelFilePath;
-       int _backendType {};
-       int _targetDeviceType {};
 
        void getOutputNames(std::vector<std::string> &names);
        void getOutputTensor(std::string target_name, std::vector<float> &tensor);
-       void parseMetaFile(std::string meta_file_name);
-       template<typename T> void inference(std::vector<std::vector<T> > &inputVectors);
+       void inference(std::vector<std::vector<T> > &inputVectors);
        virtual ImageSegmentationResult &result() = 0;
 
 public:
-       explicit ImageSegmentation(ImageSegmentationTaskType task_type);
+       explicit ImageSegmentation(std::shared_ptr<MachineLearningConfig> config);
        virtual ~ImageSegmentation() = default;
 
        void preDestroy() override;
-       ImageSegmentationTaskType getTaskType() override;
        void setUserModel(std::string model_file, std::string meta_file, std::string label_file) override;
        void setEngineInfo(std::string engine_type, std::string device_type) override;
        void getNumberOfEngines(unsigned int *number_of_engines) override;
        void getEngineType(unsigned int engine_index, char **engine_type) override;
        void getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices) override;
        void getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type) override;
-       void configure(std::string configFile) override;
+       void configure() override;
        void prepare() override;
        void perform(mv_source_h &mv_src) override;
        void performAsync(ImageSegmentationInput &input) override;
index 08c391e..30b9f89 100644 (file)
@@ -43,14 +43,13 @@ public:
        virtual ~ImageSegmentationExternal();
 
        void preDestroy() override;
-       ImageSegmentationTaskType getTaskType() override;
        void setUserModel(std::string model_file, std::string meta_file, std::string label_file) override;
        void setEngineInfo(std::string engine_type, std::string device_type) override;
        void getNumberOfEngines(unsigned int *number_of_engines) override;
        void getEngineType(unsigned int engine_index, char **engine_type) override;
        void getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices) override;
        void getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type) override;
-       void configure(std::string configFile) override;
+       void configure() override;
        void prepare() override;
        void perform(mv_source_h &mv_src) override;
        void performAsync(ImageSegmentationInput &input) override;
index 7c1fd2a..86c797c 100644 (file)
@@ -22,6 +22,7 @@
 #include "EngineConfig.h"
 #include "itask.h"
 #include "image_segmentation.h"
+#include "iimage_segmentation.h"
 
 namespace mediavision
 {
@@ -31,14 +32,12 @@ template<typename T, typename V> class ImageSegmentationAdapter : public mediavi
 {
 private:
        std::unique_ptr<IImageSegmentation> _selfie_segmentation;
+       std::shared_ptr<MachineLearningConfig> _config;
        T _source;
-       std::string _model_name;
-       std::string _model_file;
-       std::string _meta_file;
-       std::string _label_file;
-       const std::string _meta_file_name = "selfie_segmentation.json";
+       const std::string _config_file_name = "selfie_segmentation.json";
 
-       void create(ImageSegmentationTaskType task_type);
+       void create(const std::string &model_name);
+       template<typename U> void create(ImageSegmentationTaskType task_type);
        ImageSegmentationTaskType convertToTaskType(std::string model_name);
 
 public:
index 8709d53..2868954 100644 (file)
@@ -34,16 +34,12 @@ namespace mediavision
 {
 namespace machine_learning
 {
-ImageSegmentation::ImageSegmentation(ImageSegmentationTaskType task_type)
-               : _task_type(task_type)
-               , _backendType(MV_INFERENCE_BACKEND_NONE)
-               , _targetDeviceType(MV_INFERENCE_TARGET_DEVICE_NONE)
+template<typename T> ImageSegmentation<T>::ImageSegmentation(shared_ptr<MachineLearningConfig> config) : _config(config)
 {
        _inference = make_unique<Inference>();
-       _parser = make_unique<ImageSegmentationParser>();
 }
 
-void ImageSegmentation::preDestroy()
+template<typename T> void ImageSegmentation<T>::preDestroy()
 {
        if (!_async_manager)
                return;
@@ -51,12 +47,7 @@ void ImageSegmentation::preDestroy()
        _async_manager->stop();
 }
 
-ImageSegmentationTaskType ImageSegmentation::getTaskType()
-{
-       return _task_type;
-}
-
-void ImageSegmentation::getEngineList()
+template<typename T> void ImageSegmentation<T>::getEngineList()
 {
        for (auto idx = MV_INFERENCE_BACKEND_NONE + 1; idx < MV_INFERENCE_BACKEND_MAX; ++idx) {
                auto backend = _inference->getSupportedInferenceBackend(idx);
@@ -68,7 +59,7 @@ void ImageSegmentation::getEngineList()
        }
 }
 
-void ImageSegmentation::getDeviceList(const char *engine_type)
+template<typename T> void ImageSegmentation<T>::getDeviceList(const char *engine_type)
 {
        // TODO. add device types available for a given engine type later.
        //       In default, cpu and gpu only.
@@ -76,26 +67,33 @@ void ImageSegmentation::getDeviceList(const char *engine_type)
        _valid_devices.push_back("gpu");
 }
 
-void ImageSegmentation::setEngineInfo(std::string engine_type, std::string device_type)
+template<typename T>
+void ImageSegmentation<T>::setUserModel(std::string model_file, std::string meta_file, std::string label_file)
+{}
+
+template<typename T>
+void ImageSegmentation<T>::setEngineInfo(std::string engine_type_name, std::string device_type_name)
 {
-       if (engine_type.empty() || device_type.empty())
+       if (engine_type_name.empty() || device_type_name.empty())
                throw InvalidParameter("Invalid engine info.");
 
-       transform(engine_type.begin(), engine_type.end(), engine_type.begin(), ::toupper);
-       transform(device_type.begin(), device_type.end(), device_type.begin(), ::toupper);
-
-       _backendType = GetBackendType(engine_type);
-       _targetDeviceType = GetDeviceType(device_type);
+       transform(engine_type_name.begin(), engine_type_name.end(), engine_type_name.begin(), ::toupper);
+       transform(device_type_name.begin(), device_type_name.end(), device_type_name.begin(), ::toupper);
 
-       LOGI("Engine type : %s => %d, Device type : %s => %d", engine_type.c_str(), GetBackendType(engine_type),
-                device_type.c_str(), GetDeviceType(device_type));
+       int engine_type = GetBackendType(engine_type_name);
+       int device_type = GetDeviceType(device_type_name);
 
-       if (_backendType == MEDIA_VISION_ERROR_INVALID_PARAMETER ||
-               _targetDeviceType == MEDIA_VISION_ERROR_INVALID_PARAMETER)
+       if (engine_type == MEDIA_VISION_ERROR_INVALID_PARAMETER || device_type == MEDIA_VISION_ERROR_INVALID_PARAMETER)
                throw InvalidParameter("backend or target device type not found.");
+
+       _config->setBackendType(engine_type);
+       _config->setTargetDeviceType(device_type);
+
+       LOGI("Engine type : %s => %d, Device type : %s => %d", engine_type_name.c_str(), engine_type,
+                device_type_name.c_str(), device_type);
 }
 
-void ImageSegmentation::getNumberOfEngines(unsigned int *number_of_engines)
+template<typename T> void ImageSegmentation<T>::getNumberOfEngines(unsigned int *number_of_engines)
 {
        if (!_valid_backends.empty()) {
                *number_of_engines = _valid_backends.size();
@@ -106,7 +104,7 @@ void ImageSegmentation::getNumberOfEngines(unsigned int *number_of_engines)
        *number_of_engines = _valid_backends.size();
 }
 
-void ImageSegmentation::getEngineType(unsigned int engine_index, char **engine_type)
+template<typename T> void ImageSegmentation<T>::getEngineType(unsigned int engine_index, char **engine_type)
 {
        if (!_valid_backends.empty()) {
                if (_valid_backends.size() <= engine_index)
@@ -124,7 +122,8 @@ void ImageSegmentation::getEngineType(unsigned int engine_index, char **engine_t
        *engine_type = const_cast<char *>(_valid_backends[engine_index].data());
 }
 
-void ImageSegmentation::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
+template<typename T>
+void ImageSegmentation<T>::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
 {
        if (!_valid_devices.empty()) {
                *number_of_devices = _valid_devices.size();
@@ -135,7 +134,8 @@ void ImageSegmentation::getNumberOfDevices(const char *engine_type, unsigned int
        *number_of_devices = _valid_devices.size();
 }
 
-void ImageSegmentation::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
+template<typename T>
+void ImageSegmentation<T>::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
 {
        if (!_valid_devices.empty()) {
                if (_valid_devices.size() <= device_index)
@@ -153,27 +153,18 @@ void ImageSegmentation::getDeviceType(const char *engine_type, const unsigned in
        *device_type = const_cast<char *>(_valid_devices[device_index].data());
 }
 
-void ImageSegmentation::setUserModel(string model_file, string meta_file, string label_file)
-{
-       _modelFilePath = model_file;
-       _modelMetaFilePath = meta_file;
-       _modelLabelFilePath = label_file;
-}
-
-static bool IsJsonFile(const string &fileName)
+template<typename T> void ImageSegmentation<T>::loadLabel()
 {
-       return (!fileName.substr(fileName.find_last_of(".") + 1).compare("json"));
-}
+       if (_config->getLabelFilePath().empty())
+               return;
 
-void ImageSegmentation::loadLabel()
-{
        ifstream readFile;
 
        _labels.clear();
-       readFile.open(_modelLabelFilePath.c_str());
+       readFile.open(_config->getLabelFilePath().c_str());
 
        if (readFile.fail())
-               throw InvalidOperation("Fail to open " + _modelLabelFilePath + " file.");
+               throw InvalidOperation("Fail to open " + _config->getLabelFilePath() + " file.");
 
        string line;
 
@@ -183,88 +174,26 @@ void ImageSegmentation::loadLabel()
        readFile.close();
 }
 
-void ImageSegmentation::parseMetaFile(string meta_file_name)
+template<typename T> void ImageSegmentation<T>::configure()
 {
-       int ret = MEDIA_VISION_ERROR_NONE;
-       _config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + meta_file_name);
-
-       if (_backendType == MV_INFERENCE_BACKEND_NONE) {
-               ret = _config->getIntegerAttribute(MV_IMAGE_SEGMENTATION_BACKEND_TYPE, &_backendType);
-               if (ret != MEDIA_VISION_ERROR_NONE)
-                       throw InvalidOperation("Fail to get backend engine type.");
-       }
-
-       if (_targetDeviceType == MV_INFERENCE_TARGET_DEVICE_NONE) {
-               ret = _config->getIntegerAttribute(MV_IMAGE_SEGMENTATION_TARGET_DEVICE_TYPE, &_targetDeviceType);
-               if (ret != MEDIA_VISION_ERROR_NONE)
-                       throw InvalidOperation("Fail to get target device type.");
-       }
-
-       ret = _config->getStringAttribute(MV_IMAGE_SEGMENTATION_MODEL_DEFAULT_PATH, &_modelDefaultPath);
-       if (ret != MEDIA_VISION_ERROR_NONE)
-               throw InvalidOperation("Fail to get model default path");
-
-       if (_modelFilePath.empty()) {
-               ret = _config->getStringAttribute(MV_IMAGE_SEGMENTATION_MODEL_FILE_PATH, &_modelFilePath);
-               if (ret != MEDIA_VISION_ERROR_NONE)
-                       throw InvalidOperation("Fail to get model file path");
-       }
-
-       _modelFilePath = _modelDefaultPath + _modelFilePath;
-       LOGI("model file path = %s", _modelFilePath.c_str());
-
-       if (_modelMetaFilePath.empty()) {
-               ret = _config->getStringAttribute(MV_IMAGE_SEGMENTATION_MODEL_META_FILE_PATH, &_modelMetaFilePath);
-               if (ret != MEDIA_VISION_ERROR_NONE)
-                       throw InvalidOperation("Fail to get model meta file path");
-
-               if (_modelMetaFilePath.empty())
-                       throw InvalidOperation("Model meta file doesn't exist.");
-
-               if (!IsJsonFile(_modelMetaFilePath))
-                       throw InvalidOperation("Model meta file should be json");
-       }
-
-       _modelMetaFilePath = _modelDefaultPath + _modelMetaFilePath;
-       LOGI("meta file path = %s", _modelMetaFilePath.c_str());
-
-       _parser->load(_modelMetaFilePath);
-
-       if (_modelLabelFilePath.empty()) {
-               ret = _config->getStringAttribute(MV_IMAGE_SEGMENTATION_LABEL_FILE_NAME, &_modelLabelFilePath);
-               if (ret != MEDIA_VISION_ERROR_NONE)
-                       throw InvalidOperation("Fail to get label file path");
-
-               if (_modelLabelFilePath.empty())
-                       throw InvalidOperation("Model label file doesn't exist.");
-       }
-
-       _modelLabelFilePath = _modelDefaultPath + _modelLabelFilePath;
-       LOGI("label file path = %s", _modelLabelFilePath.c_str());
-
        loadLabel();
-}
 
-void ImageSegmentation::configure(string configFile)
-{
-       parseMetaFile(configFile);
-
-       int ret = _inference->bind(_backendType, _targetDeviceType);
+       int ret = _inference->bind(_config->getBackendType(), _config->getTargetDeviceType());
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to bind a backend engine.");
 }
 
-void ImageSegmentation::prepare()
+template<typename T> void ImageSegmentation<T>::prepare()
 {
-       int ret = _inference->configureInputMetaInfo(_parser->getInputMetaMap());
+       int ret = _inference->configureInputMetaInfo(_config->getInputMetaMap());
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to configure input tensor info from meta file.");
 
-       ret = _inference->configureOutputMetaInfo(_parser->getOutputMetaMap());
+       ret = _inference->configureOutputMetaInfo(_config->getOutputMetaMap());
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to configure output tensor info from meta file.");
 
-       _inference->configureModelFiles("", _modelFilePath, "");
+       _inference->configureModelFiles("", _config->getModelFilePath(), "");
 
        // Request to load model files to a backend engine.
        ret = _inference->load();
@@ -272,7 +201,7 @@ void ImageSegmentation::prepare()
                throw InvalidOperation("Fail to load model files.");
 }
 
-shared_ptr<MetaInfo> ImageSegmentation::getInputMetaInfo()
+template<typename T> shared_ptr<MetaInfo> ImageSegmentation<T>::getInputMetaInfo()
 {
        TensorBuffer &tensor_buffer = _inference->getInputTensorBuffer();
        IETensorBuffer &tensor_info_map = tensor_buffer.getIETensorBuffer();
@@ -284,11 +213,11 @@ shared_ptr<MetaInfo> ImageSegmentation::getInputMetaInfo()
        auto tensor_buffer_iter = tensor_info_map.begin();
 
        // Get the meta information corresponding to a given input tensor name.
-       return _parser->getInputMetaMap()[tensor_buffer_iter->first];
+       return _config->getInputMetaMap()[tensor_buffer_iter->first];
 }
 
 template<typename T>
-void ImageSegmentation::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
+void ImageSegmentation<T>::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
 {
        LOGI("ENTER");
 
@@ -320,7 +249,7 @@ void ImageSegmentation::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> met
        LOGI("LEAVE");
 }
 
-template<typename T> void ImageSegmentation::inference(vector<vector<T> > &inputVectors)
+template<typename T> void ImageSegmentation<T>::inference(vector<vector<T> > &inputVectors)
 {
        LOGI("ENTER");
 
@@ -331,37 +260,24 @@ template<typename T> void ImageSegmentation::inference(vector<vector<T> > &input
        LOGI("LEAVE");
 }
 
-template<typename T> void ImageSegmentation::perform(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo)
+template<typename T> void ImageSegmentation<T>::perform(mv_source_h &mv_src)
 {
+       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
        vector<T> inputVector;
 
-       preprocess<T>(mv_src, metaInfo, inputVector);
+       preprocess(mv_src, metaInfo, inputVector);
 
        vector<vector<T> > inputVectors = { inputVector };
-
-       inference<T>(inputVectors);
-
-       // TODO. Update operation status here.
+       inference(inputVectors);
 }
 
-void ImageSegmentation::perform(mv_source_h &mv_src)
-{
-       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
-       if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8)
-               perform<unsigned char>(mv_src, metaInfo);
-       else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32)
-               perform<float>(mv_src, metaInfo);
-       else
-               throw InvalidOperation("Invalid model data type.");
-}
-
-template<typename T> void ImageSegmentation::performAsync(ImageSegmentationInput &input, shared_ptr<MetaInfo> metaInfo)
+template<typename T> void ImageSegmentation<T>::performAsync(ImageSegmentationInput &input)
 {
        if (!_async_manager) {
                _async_manager = make_unique<AsyncManager<ImageSegmentationResult> >([this]() {
                        AsyncInputQueue<T> inputQueue = _async_manager->popFromInput<T>();
 
-                       inference<T>(inputQueue.inputs);
+                       inference(inputQueue.inputs);
 
                        ImageSegmentationResult &resultQueue = result();
 
@@ -370,30 +286,16 @@ template<typename T> void ImageSegmentation::performAsync(ImageSegmentationInput
                });
        }
 
+       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
        vector<T> inputVector;
 
-       preprocess<T>(input.inference_src, metaInfo, inputVector);
+       preprocess(input.inference_src, metaInfo, inputVector);
 
        vector<vector<T> > inputVectors = { inputVector };
-
        _async_manager->push(inputVectors);
 }
 
-void ImageSegmentation::performAsync(ImageSegmentationInput &input)
-{
-       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
-
-       if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8) {
-               performAsync<unsigned char>(input, metaInfo);
-       } else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32) {
-               performAsync<float>(input, metaInfo);
-               // TODO
-       } else {
-               throw InvalidOperation("Invalid model data type.");
-       }
-}
-
-ImageSegmentationResult &ImageSegmentation::getOutput()
+template<typename T> ImageSegmentationResult &ImageSegmentation<T>::getOutput()
 {
        if (_async_manager) {
                if (!_async_manager->isWorking())
@@ -409,12 +311,12 @@ ImageSegmentationResult &ImageSegmentation::getOutput()
        return _current_result;
 }
 
-ImageSegmentationResult &ImageSegmentation::getOutputCache()
+template<typename T> ImageSegmentationResult &ImageSegmentation<T>::getOutputCache()
 {
        return _current_result;
 }
 
-void ImageSegmentation::getOutputNames(vector<string> &names)
+template<typename T> void ImageSegmentation<T>::getOutputNames(vector<string> &names)
 {
        TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
        IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
@@ -423,7 +325,7 @@ void ImageSegmentation::getOutputNames(vector<string> &names)
                names.push_back(it->first);
 }
 
-void ImageSegmentation::getOutputTensor(string target_name, vector<float> &tensor)
+template<typename T> void ImageSegmentation<T>::getOutputTensor(string target_name, vector<float> &tensor)
 {
        TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
 
@@ -436,18 +338,8 @@ void ImageSegmentation::getOutputTensor(string target_name, vector<float> &tenso
        copy(&raw_buffer[0], &raw_buffer[tensor_buffer->size / sizeof(float)], back_inserter(tensor));
 }
 
-template void ImageSegmentation::preprocess<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
-                                                                                                  vector<float> &inputVector);
-template void ImageSegmentation::inference<float>(vector<vector<float> > &inputVectors);
-template void ImageSegmentation::perform<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
-template void ImageSegmentation::performAsync<float>(ImageSegmentationInput &input, shared_ptr<MetaInfo> metaInfo);
-
-template void ImageSegmentation::preprocess<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
-                                                                                                                  vector<unsigned char> &inputVector);
-template void ImageSegmentation::inference<unsigned char>(vector<vector<unsigned char> > &inputVectors);
-template void ImageSegmentation::perform<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
-template void ImageSegmentation::performAsync<unsigned char>(ImageSegmentationInput &input,
-                                                                                                                        shared_ptr<MetaInfo> metaInfo);
+template class ImageSegmentation<unsigned char>;
+template class ImageSegmentation<float>;
 
 }
 }
\ No newline at end of file
index 24ac9f5..d88ec01 100644 (file)
@@ -77,11 +77,6 @@ void ImageSegmentationExternal::preDestroy()
        _image_segmentation_plugin->preDestroy();
 }
 
-ImageSegmentationTaskType ImageSegmentationExternal::getTaskType()
-{
-       return _image_segmentation_plugin->getTaskType();
-}
-
 void ImageSegmentationExternal::setUserModel(std::string model_file, std::string meta_file, std::string label_file)
 {
        _image_segmentation_plugin->setUserModel(model_file, meta_file, label_file);
@@ -113,9 +108,9 @@ void ImageSegmentationExternal::getDeviceType(const char *engine_type, const uns
        _image_segmentation_plugin->getDeviceType(engine_type, device_index, device_type);
 }
 
-void ImageSegmentationExternal::configure(std::string configFile)
+void ImageSegmentationExternal::configure()
 {
-       _image_segmentation_plugin->configure(configFile);
+       _image_segmentation_plugin->configure();
 }
 
 void ImageSegmentationExternal::prepare()
index f920df9..2c97801 100644 (file)
@@ -30,15 +30,10 @@ namespace machine_learning
 {
 template<typename T, typename V> ImageSegmentationAdapter<T, V>::ImageSegmentationAdapter() : _source()
 {
-       auto config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + _meta_file_name);
+       _config = make_shared<MachineLearningConfig>();
+       _config->parseConfigFile(_config_file_name);
 
-       string defaultModelName;
-
-       int ret = config->getStringAttribute(MV_IMAGE_SEGMENTATION_DEFAULT_MODEL_NAME, &defaultModelName);
-       if (ret != MEDIA_VISION_ERROR_NONE)
-               throw InvalidOperation("Fail to get default model name.");
-
-       create(convertToTaskType(defaultModelName.c_str()));
+       create(_config->getDefaultModelName());
 }
 
 template<typename T, typename V> ImageSegmentationAdapter<T, V>::~ImageSegmentationAdapter()
@@ -46,20 +41,36 @@ template<typename T, typename V> ImageSegmentationAdapter<T, V>::~ImageSegmentat
        _selfie_segmentation->preDestroy();
 }
 
-template<typename T, typename V> void ImageSegmentationAdapter<T, V>::create(ImageSegmentationTaskType task_type)
+template<typename T, typename V>
+template<typename U>
+void ImageSegmentationAdapter<T, V>::create(ImageSegmentationTaskType task_type)
 {
-       // If a concrete class object created already exists, reset the object
-       // so that other concrete class object can be created again according to a given task_type.
-       if (_selfie_segmentation) {
-               // If default task type is same as a given one then skip.
-               if (_selfie_segmentation->getTaskType() == task_type)
-                       return;
-
-               _selfie_segmentation.reset();
+       switch (task_type) {
+       case ImageSegmentationTaskType::SELFIE_SEGMENTATION_PLUGIN:
+               _selfie_segmentation = make_unique<ImageSegmentationExternal>(task_type);
+               break;
+       default:
+               throw InvalidOperation("Invalid image segmentation task type.");
        }
+       // TODO.
+}
 
-       if (task_type == ImageSegmentationTaskType::SELFIE_SEGMENTATION_PLUGIN)
-               _selfie_segmentation = make_unique<ImageSegmentationExternal>(task_type);
+template<typename T, typename V> void ImageSegmentationAdapter<T, V>::create(const string &model_name)
+{
+       ImageSegmentationTaskType task_type = convertToTaskType(model_name);
+       _config->loadMetaFile(make_unique<ImageSegmentationParser>(static_cast<int>(task_type)));
+       mv_inference_data_type_e dataType = _config->getInputMetaMap().begin()->second->dataType;
+
+       switch (dataType) {
+       case MV_INFERENCE_DATA_UINT8:
+               create<unsigned char>(task_type);
+               break;
+       case MV_INFERENCE_DATA_FLOAT32:
+               create<float>(task_type);
+               break;
+       default:
+               throw InvalidOperation("Invalid image segmentation data type.");
+       }
 }
 
 template<typename T, typename V>
@@ -81,24 +92,18 @@ void ImageSegmentationAdapter<T, V>::setModelInfo(const char *model_file, const
                                                                                                  const char *model_name)
 {
        try {
-               create(convertToTaskType(string(model_name)));
+               _config->setUserModel(model_file, meta_file, label_file);
+               create(model_name);
        } catch (const BaseException &e) {
                LOGW("A given model name is invalid so default task type will be used.");
        }
 
-       if (model_file)
-               _model_file = model_file;
-       if (meta_file)
-               _meta_file = meta_file;
-       if (label_file)
-               _label_file = label_file;
-
-       if (_model_file.empty() && _meta_file.empty() && _label_file.empty()) {
+       if (!model_file && !meta_file) {
                LOGW("Given model info is invalid so default model info will be used instead.");
                return;
        }
 
-       _selfie_segmentation->setUserModel(_model_file, _meta_file, _label_file);
+       _selfie_segmentation->setUserModel(model_file, meta_file, label_file);
 }
 
 template<typename T, typename V>
@@ -109,7 +114,7 @@ void ImageSegmentationAdapter<T, V>::setEngineInfo(const char *engine_type, cons
 
 template<typename T, typename V> void ImageSegmentationAdapter<T, V>::configure()
 {
-       _selfie_segmentation->configure(_meta_file_name);
+       _selfie_segmentation->configure();
 }
 
 template<typename T, typename V>