mv_machine_learning: convert LandmarkDetection class into template class
author Vibhav Aggarwal <v.aggarwal@samsung.com>
Tue, 14 Nov 2023 09:39:21 +0000 (18:39 +0900)
committer Kwanghoon Son <k.son@samsung.com>
Wed, 6 Dec 2023 01:36:46 +0000 (10:36 +0900)
[Issue type] code refactoring

Change-Id: Ib4a040c5a6b97fd19ab3696247e2b9df11ce9ea0
Signed-off-by: Vibhav Aggarwal <v.aggarwal@samsung.com>
mv_machine_learning/landmark_detection/include/facial_landmark_adapter.h
mv_machine_learning/landmark_detection/include/fld_tweak_cnn.h
mv_machine_learning/landmark_detection/include/ilandmark_detection.h [new file with mode: 0644]
mv_machine_learning/landmark_detection/include/landmark_detection.h
mv_machine_learning/landmark_detection/include/pld_cpm.h
mv_machine_learning/landmark_detection/include/pose_landmark_adapter.h
mv_machine_learning/landmark_detection/src/facial_landmark_adapter.cpp
mv_machine_learning/landmark_detection/src/fld_tweak_cnn.cpp
mv_machine_learning/landmark_detection/src/landmark_detection.cpp
mv_machine_learning/landmark_detection/src/pld_cpm.cpp
mv_machine_learning/landmark_detection/src/pose_landmark_adapter.cpp

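diff --git a/mv_machine_learning/landmark_detection/include/facial_landmark_adapter.h b/mv_machine_learning/landmark_detection/include/facial_landmark_adapter.h
index 98daae4..b4b0f13 100644 (file)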
@@ -23,6 +23,7 @@
 #include "itask.h"
 #include "machine_learning_config.h"
 #include "fld_tweak_cnn.h"
+#include "ilandmark_detection.h"
 
 namespace mediavision
 {
@@ -31,12 +32,13 @@ namespace machine_learning
 template<typename T, typename V> class FacialLandmarkAdapter : public mediavision::common::ITask<T, V>
 {
 private:
-       std::unique_ptr<LandmarkDetection> _landmark_detection;
+       std::unique_ptr<ILandmarkDetection> _landmark_detection;
        std::shared_ptr<MachineLearningConfig> _config;
        T _source;
        const std::string _config_file_name = "facial_landmark.json";
 
-       void create(LandmarkDetectionTaskType task_type);
+       void create(const std::string &model_name);
+       template<typename U> void create(LandmarkDetectionTaskType task_type);
        LandmarkDetectionTaskType convertToTaskType(std::string model_name);
 
 public:
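diff --git a/mv_machine_learning/landmark_detection/include/fld_tweak_cnn.h b/mv_machine_learning/landmark_detection/include/fld_tweak_cnn.h
index 42dc2c1..08d4b67 100644 (file)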
@@ -29,8 +29,12 @@ namespace mediavision
 {
 namespace machine_learning
 {
-class FldTweakCnn : public LandmarkDetection
+template<typename T> class FldTweakCnn : public LandmarkDetection<T>
 {
+       using LandmarkDetection<T>::_config;
+       using LandmarkDetection<T>::_preprocess;
+       using LandmarkDetection<T>::_inference;
+
 private:
        LandmarkDetectionResult _result;
 
diff --git a/mv_machine_learning/landmark_detection/include/ilandmark_detection.h b/mv_machine_learning/landmark_detection/include/ilandmark_detection.h
new file mode 100644 (file)
index 0000000..ca05d7e
--- /dev/null
@@ -0,0 +1,50 @@
+/**
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ILANDMARK_DETECTION_H__
+#define __ILANDMARK_DETECTION_H__
+
+#include <mv_common.h>
+
+#include "landmark_detection_type.h"
+
+namespace mediavision
+{
+namespace machine_learning
+{
+class ILandmarkDetection
+{
+public:
+       virtual ~ILandmarkDetection() {};
+
+       virtual void preDestroy() = 0;
+       virtual LandmarkDetectionTaskType getTaskType() = 0;
+       virtual void setEngineInfo(std::string engine_type, std::string device_type) = 0;
+       virtual void getNumberOfEngines(unsigned int *number_of_engines) = 0;
+       virtual void getEngineType(unsigned int engine_index, char **engine_type) = 0;
+       virtual void getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices) = 0;
+       virtual void getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type) = 0;
+       virtual void configure() = 0;
+       virtual void prepare() = 0;
+       virtual void perform(mv_source_h &mv_src) = 0;
+       virtual void performAsync(LandmarkDetectionInput &input) = 0;
+       virtual LandmarkDetectionResult &getOutput() = 0;
+};
+
+} // machine_learning
+} // mediavision
+
+#endif
\ No newline at end of file
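diff --git a/mv_machine_learning/landmark_detection/include/landmark_detection.h b/mv_machine_learning/landmark_detection/include/landmark_detection.h
index 1556d94..6f55f0e 100644 (file)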
 #include "machine_learning_config.h"
 #include "machine_learning_preprocess.h"
 #include "async_manager.h"
+#include "ilandmark_detection.h"
 
 namespace mediavision
 {
 namespace machine_learning
 {
-class LandmarkDetection
+template<typename T> class LandmarkDetection : public ILandmarkDetection
 {
 private:
        std::unique_ptr<AsyncManager<LandmarkDetectionResult> > _async_manager;
@@ -46,11 +47,8 @@ private:
        void getEngineList();
        void getDeviceList(const char *engine_type);
 
-       template<typename T>
        void preprocess(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo, std::vector<T> &inputVector);
        std::shared_ptr<MetaInfo> getInputMetaInfo();
-       template<typename T> void perform(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo);
-       template<typename T> void performAsync(LandmarkDetectionInput &input, std::shared_ptr<MetaInfo> metaInfo);
 
 protected:
        std::unique_ptr<mediavision::inference::Inference> _inference;
@@ -62,7 +60,7 @@ protected:
 
        void getOutputNames(std::vector<std::string> &names);
        void getOutputTensor(std::string target_name, std::vector<float> &tensor);
-       template<typename T> void inference(std::vector<std::vector<T> > &inputVectors);
+       void inference(std::vector<std::vector<T> > &inputVectors);
        virtual LandmarkDetectionResult &result() = 0;
 
 public:
@@ -70,7 +68,6 @@ public:
        virtual ~LandmarkDetection() = default;
        void preDestroy();
        LandmarkDetectionTaskType getTaskType();
-       void setUserModel(std::string model_file, std::string meta_file, std::string label_file);
        void setEngineInfo(std::string engine_type_name, std::string device_type_name);
        void getNumberOfEngines(unsigned int *number_of_engines);
        void getEngineType(unsigned int engine_index, char **engine_type);
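diff --git a/mv_machine_learning/landmark_detection/include/pld_cpm.h b/mv_machine_learning/landmark_detection/include/pld_cpm.h
index 7ce71b6..993e975 100644 (file)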
@@ -29,8 +29,11 @@ namespace mediavision
 {
 namespace machine_learning
 {
-class PldCpm : public LandmarkDetection
+template<typename T> class PldCpm : public LandmarkDetection<T>
 {
+       using LandmarkDetection<T>::_config;
+       using LandmarkDetection<T>::_preprocess;
+
 private:
        LandmarkDetectionResult _result;
 
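diff --git a/mv_machine_learning/landmark_detection/include/pose_landmark_adapter.h b/mv_machine_learning/landmark_detection/include/pose_landmark_adapter.h
index 4d21c3f..717ae58 100644 (file)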
@@ -23,6 +23,7 @@
 #include "itask.h"
 #include "machine_learning_config.h"
 #include "pld_cpm.h"
+#include "ilandmark_detection.h"
 
 namespace mediavision
 {
@@ -31,12 +32,13 @@ namespace machine_learning
 template<typename T, typename V> class PoseLandmarkAdapter : public mediavision::common::ITask<T, V>
 {
 private:
-       std::unique_ptr<LandmarkDetection> _landmark_detection;
+       std::unique_ptr<ILandmarkDetection> _landmark_detection;
        std::shared_ptr<MachineLearningConfig> _config;
        T _source;
        const std::string _config_file_name = "pose_landmark.json";
 
-       void create(LandmarkDetectionTaskType task_type);
+       void create(const std::string &model_name);
+       template<typename U> void create(LandmarkDetectionTaskType task_type);
        LandmarkDetectionTaskType convertToTaskType(std::string model_name);
 
 public:
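diff --git a/mv_machine_learning/landmark_detection/src/facial_landmark_adapter.cpp b/mv_machine_learning/landmark_detection/src/facial_landmark_adapter.cpp
index 370375c..ccadda6 100644 (file)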
@@ -32,8 +32,7 @@ template<typename T, typename V> FacialLandmarkAdapter<T, V>::FacialLandmarkAdap
        _config = make_shared<MachineLearningConfig>();
        _config->parseConfigFile(_config_file_name);
 
-       LandmarkDetectionTaskType model_type = convertToTaskType(_config->getDefaultModelName());
-       create(model_type);
+       create(_config->getDefaultModelName());
 }
 
 template<typename T, typename V> FacialLandmarkAdapter<T, V>::~FacialLandmarkAdapter()
@@ -41,16 +40,35 @@ template<typename T, typename V> FacialLandmarkAdapter<T, V>::~FacialLandmarkAda
        _landmark_detection->preDestroy();
 }
 
-template<typename T, typename V> void FacialLandmarkAdapter<T, V>::create(LandmarkDetectionTaskType task_type)
-{
-       if (_landmark_detection) {
-               // If current task type is same as a given one then skip.
-               if (_landmark_detection->getTaskType() == task_type)
-                       return;
+template<typename T, typename V>
+template<typename U>
+void FacialLandmarkAdapter<T, V>::create(LandmarkDetectionTaskType task_type)
+{
+       switch (task_type) {
+       case LandmarkDetectionTaskType::FLD_TWEAK_CNN:
+               _landmark_detection = make_unique<FldTweakCnn<U> >(task_type, _config);
+               break;
+       default:
+               throw InvalidOperation("Invalid landmark detection task type.");
        }
+}
 
-       if (task_type == LandmarkDetectionTaskType::FLD_TWEAK_CNN)
-               _landmark_detection = make_unique<FldTweakCnn>(task_type, _config);
+template<typename T, typename V> void FacialLandmarkAdapter<T, V>::create(const string &model_name)
+{
+       LandmarkDetectionTaskType task_type = convertToTaskType(model_name);
+       _config->loadMetaFile(make_unique<LandmarkDetectionParser>(static_cast<int>(task_type)));
+       mv_inference_data_type_e dataType = _config->getInputMetaMap().begin()->second->dataType;
+
+       switch (dataType) {
+       case MV_INFERENCE_DATA_UINT8:
+               create<unsigned char>(task_type);
+               break;
+       case MV_INFERENCE_DATA_FLOAT32:
+               create<float>(task_type);
+               break;
+       default:
+               throw InvalidOperation("Invalid landmark detection data type.");
+       }
 }
 
 template<typename T, typename V>
@@ -74,9 +92,7 @@ void FacialLandmarkAdapter<T, V>::setModelInfo(const char *model_file, const cha
 {
        try {
                _config->setUserModel(model_file, meta_file, label_file);
-
-               LandmarkDetectionTaskType model_type = convertToTaskType(model_name);
-               create(model_type);
+               create(model_name);
        } catch (const BaseException &e) {
                LOGW("A given model name is invalid so default task type will be used.");
        }
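diff --git a/mv_machine_learning/landmark_detection/src/fld_tweak_cnn.cpp b/mv_machine_learning/landmark_detection/src/fld_tweak_cnn.cpp
index a4b922f..1fdecec 100644 (file)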
@@ -30,14 +30,15 @@ namespace mediavision
 {
 namespace machine_learning
 {
-FldTweakCnn::FldTweakCnn(LandmarkDetectionTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
-               : LandmarkDetection(task_type, config), _result()
+template<typename T>
+FldTweakCnn<T>::FldTweakCnn(LandmarkDetectionTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
+               : LandmarkDetection<T>(task_type, config), _result()
 {}
 
-FldTweakCnn::~FldTweakCnn()
+template<typename T> FldTweakCnn<T>::~FldTweakCnn()
 {}
 
-LandmarkDetectionResult &FldTweakCnn::result()
+template<typename T> LandmarkDetectionResult &FldTweakCnn<T>::result()
 {
        constexpr static unsigned int numberOfLandmarks = 5;
 
@@ -47,7 +48,7 @@ LandmarkDetectionResult &FldTweakCnn::result()
 
        vector<string> names;
 
-       LandmarkDetection::getOutputNames(names);
+       LandmarkDetection<T>::getOutputNames(names);
 
        auto scoreMetaInfo = _config->getOutputMetaMap().at(names[0]);
        auto decodingLandmark =
@@ -71,7 +72,7 @@ LandmarkDetectionResult &FldTweakCnn::result()
 
        vector<float> score_tensor;
 
-       LandmarkDetection::getOutputTensor(names[0], score_tensor);
+       LandmarkDetection<T>::getOutputTensor(names[0], score_tensor);
 
        // Calculate the ratio[A] between the original image size and the input tensor size.
        double width_ratio = ori_src_width / input_tensor_width;
@@ -92,5 +93,8 @@ LandmarkDetectionResult &FldTweakCnn::result()
        return _result;
 }
 
+template class FldTweakCnn<unsigned char>;
+template class FldTweakCnn<float>;
+
 }
 }
\ No newline at end of file
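diff --git a/mv_machine_learning/landmark_detection/src/landmark_detection.cpp b/mv_machine_learning/landmark_detection/src/landmark_detection.cpp
index 119e439..bc2bb59 100644 (file)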
@@ -34,13 +34,14 @@ namespace mediavision
 {
 namespace machine_learning
 {
-LandmarkDetection::LandmarkDetection(LandmarkDetectionTaskType task_type, shared_ptr<MachineLearningConfig> config)
+template<typename T>
+LandmarkDetection<T>::LandmarkDetection(LandmarkDetectionTaskType task_type, shared_ptr<MachineLearningConfig> config)
                : _task_type(task_type), _config(config)
 {
        _inference = make_unique<Inference>();
 }
 
-void LandmarkDetection::preDestroy()
+template<typename T> void LandmarkDetection<T>::preDestroy()
 {
        if (!_async_manager)
                return;
@@ -48,12 +49,12 @@ void LandmarkDetection::preDestroy()
        _async_manager->stop();
 }
 
-LandmarkDetectionTaskType LandmarkDetection::getTaskType()
+template<typename T> LandmarkDetectionTaskType LandmarkDetection<T>::getTaskType()
 {
        return _task_type;
 }
 
-void LandmarkDetection::getEngineList()
+template<typename T> void LandmarkDetection<T>::getEngineList()
 {
        for (auto idx = MV_INFERENCE_BACKEND_NONE + 1; idx < MV_INFERENCE_BACKEND_MAX; ++idx) {
                auto backend = _inference->getSupportedInferenceBackend(idx);
@@ -65,7 +66,7 @@ void LandmarkDetection::getEngineList()
        }
 }
 
-void LandmarkDetection::getDeviceList(const char *engine_type)
+template<typename T> void LandmarkDetection<T>::getDeviceList(const char *engine_type)
 {
        // TODO. add device types available for a given engine type later.
        //       In default, cpu and gpu only.
@@ -73,7 +74,7 @@ void LandmarkDetection::getDeviceList(const char *engine_type)
        _valid_devices.push_back("gpu");
 }
 
-void LandmarkDetection::setEngineInfo(string engine_type_name, string device_type_name)
+template<typename T> void LandmarkDetection<T>::setEngineInfo(string engine_type_name, string device_type_name)
 {
        if (engine_type_name.empty() || device_type_name.empty())
                throw InvalidParameter("Invalid engine info.");
@@ -94,7 +95,7 @@ void LandmarkDetection::setEngineInfo(string engine_type_name, string device_typ
                 device_type_name.c_str(), device_type);
 }
 
-void LandmarkDetection::getNumberOfEngines(unsigned int *number_of_engines)
+template<typename T> void LandmarkDetection<T>::getNumberOfEngines(unsigned int *number_of_engines)
 {
        if (!_valid_backends.empty()) {
                *number_of_engines = _valid_backends.size();
@@ -105,7 +106,7 @@ void LandmarkDetection::getNumberOfEngines(unsigned int *number_of_engines)
        *number_of_engines = _valid_backends.size();
 }
 
-void LandmarkDetection::getEngineType(unsigned int engine_index, char **engine_type)
+template<typename T> void LandmarkDetection<T>::getEngineType(unsigned int engine_index, char **engine_type)
 {
        if (!_valid_backends.empty()) {
                if (_valid_backends.size() <= engine_index)
@@ -123,7 +124,8 @@ void LandmarkDetection::getEngineType(unsigned int engine_index, char **engine_t
        *engine_type = const_cast<char *>(_valid_backends[engine_index].data());
 }
 
-void LandmarkDetection::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
+template<typename T>
+void LandmarkDetection<T>::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
 {
        if (!_valid_devices.empty()) {
                *number_of_devices = _valid_devices.size();
@@ -134,7 +136,8 @@ void LandmarkDetection::getNumberOfDevices(const char *engine_type, unsigned int
        *number_of_devices = _valid_devices.size();
 }
 
-void LandmarkDetection::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
+template<typename T>
+void LandmarkDetection<T>::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
 {
        if (!_valid_devices.empty()) {
                if (_valid_devices.size() <= device_index)
@@ -152,7 +155,7 @@ void LandmarkDetection::getDeviceType(const char *engine_type, const unsigned in
        *device_type = const_cast<char *>(_valid_devices[device_index].data());
 }
 
-void LandmarkDetection::loadLabel()
+template<typename T> void LandmarkDetection<T>::loadLabel()
 {
        if (_config->getLabelFilePath().empty())
                return;
@@ -173,9 +176,8 @@ void LandmarkDetection::loadLabel()
        readFile.close();
 }
 
-void LandmarkDetection::configure()
+template<typename T> void LandmarkDetection<T>::configure()
 {
-       _config->loadMetaFile(make_unique<LandmarkDetectionParser>(static_cast<int>(_task_type)));
        loadLabel();
 
        int ret = _inference->bind(_config->getBackendType(), _config->getTargetDeviceType());
@@ -183,7 +185,7 @@ void LandmarkDetection::configure()
                throw InvalidOperation("Fail to bind a backend engine.");
 }
 
-void LandmarkDetection::prepare()
+template<typename T> void LandmarkDetection<T>::prepare()
 {
        int ret = _inference->configureInputMetaInfo(_config->getInputMetaMap());
        if (ret != MEDIA_VISION_ERROR_NONE)
@@ -201,7 +203,7 @@ void LandmarkDetection::prepare()
                throw InvalidOperation("Fail to load model files.");
 }
 
-shared_ptr<MetaInfo> LandmarkDetection::getInputMetaInfo()
+template<typename T> shared_ptr<MetaInfo> LandmarkDetection<T>::getInputMetaInfo()
 {
        TensorBuffer &tensor_buffer = _inference->getInputTensorBuffer();
        IETensorBuffer &tensor_info_map = tensor_buffer.getIETensorBuffer();
@@ -217,7 +219,7 @@ shared_ptr<MetaInfo> LandmarkDetection::getInputMetaInfo()
 }
 
 template<typename T>
-void LandmarkDetection::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
+void LandmarkDetection<T>::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
 {
        LOGI("ENTER");
 
@@ -249,7 +251,7 @@ void LandmarkDetection::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> met
        LOGI("LEAVE");
 }
 
-template<typename T> void LandmarkDetection::inference(vector<vector<T> > &inputVectors)
+template<typename T> void LandmarkDetection<T>::inference(vector<vector<T> > &inputVectors)
 {
        LOGI("ENTER");
 
@@ -260,30 +262,18 @@ template<typename T> void LandmarkDetection::inference(vector<vector<T> > &input
        LOGI("LEAVE");
 }
 
-template<typename T> void LandmarkDetection::perform(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo)
+template<typename T> void LandmarkDetection<T>::perform(mv_source_h &mv_src)
 {
+       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
        vector<T> inputVector;
 
-       preprocess<T>(mv_src, metaInfo, inputVector);
+       preprocess(mv_src, metaInfo, inputVector);
 
        vector<vector<T> > inputVectors = { inputVector };
-
-       inference<T>(inputVectors);
-}
-
-void LandmarkDetection::perform(mv_source_h &mv_src)
-{
-       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
-
-       if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8)
-               perform<unsigned char>(mv_src, metaInfo);
-       else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32)
-               perform<float>(mv_src, metaInfo);
-       else
-               throw InvalidOperation("Invalid model data type.");
+       inference(inputVectors);
 }
 
-LandmarkDetectionResult &LandmarkDetection::getOutput()
+template<typename T> LandmarkDetectionResult &LandmarkDetection<T>::getOutput()
 {
        if (_async_manager) {
                if (!_async_manager->isWorking())
@@ -299,7 +289,7 @@ LandmarkDetectionResult &LandmarkDetection::getOutput()
        return _current_result;
 }
 
-template<typename T> void LandmarkDetection::performAsync(LandmarkDetectionInput &input, shared_ptr<MetaInfo> metaInfo)
+template<typename T> void LandmarkDetection<T>::performAsync(LandmarkDetectionInput &input)
 {
        if (!_async_manager) {
                _async_manager = make_unique<AsyncManager<LandmarkDetectionResult> >([this]() {
@@ -307,7 +297,7 @@ template<typename T> void LandmarkDetection::performAsync(LandmarkDetectionInput
 
                        LOGD("Poped input frame number = %ld", inputQueue.frame_number);
 
-                       inference<T>(inputQueue.inputs);
+                       inference(inputQueue.inputs);
 
                        LandmarkDetectionResult &resultQueue = result();
 
@@ -316,30 +306,16 @@ template<typename T> void LandmarkDetection::performAsync(LandmarkDetectionInput
                });
        }
 
+       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
        vector<T> inputVector;
 
-       preprocess<T>(input.inference_src, metaInfo, inputVector);
+       preprocess(input.inference_src, metaInfo, inputVector);
 
        vector<vector<T> > inputVectors = { inputVector };
-
        _async_manager->push(inputVectors);
 }
 
-void LandmarkDetection::performAsync(LandmarkDetectionInput &input)
-{
-       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
-
-       if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8) {
-               performAsync<unsigned char>(input, metaInfo);
-       } else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32) {
-               performAsync<float>(input, metaInfo);
-               // TODO
-       } else {
-               throw InvalidOperation("Invalid model data type.");
-       }
-}
-
-void LandmarkDetection::getOutputNames(vector<string> &names)
+template<typename T> void LandmarkDetection<T>::getOutputNames(vector<string> &names)
 {
        TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
        IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
@@ -348,7 +324,7 @@ void LandmarkDetection::getOutputNames(vector<string> &names)
                names.push_back(it->first);
 }
 
-void LandmarkDetection::getOutputTensor(string target_name, vector<float> &tensor)
+template<typename T> void LandmarkDetection<T>::getOutputTensor(string target_name, vector<float> &tensor)
 {
        TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
 
@@ -361,18 +337,8 @@ void LandmarkDetection::getOutputTensor(string target_name, vector<float> &tenso
        copy(&raw_buffer[0], &raw_buffer[tensor_buffer->size / sizeof(float)], back_inserter(tensor));
 }
 
-template void LandmarkDetection::preprocess<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
-                                                                                                  vector<float> &inputVector);
-template void LandmarkDetection::inference<float>(vector<vector<float> > &inputVectors);
-template void LandmarkDetection::perform<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
-template void LandmarkDetection::performAsync<float>(LandmarkDetectionInput &input, shared_ptr<MetaInfo> metaInfo);
-
-template void LandmarkDetection::preprocess<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
-                                                                                                                  vector<unsigned char> &inputVector);
-template void LandmarkDetection::inference<unsigned char>(vector<vector<unsigned char> > &inputVectors);
-template void LandmarkDetection::perform<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
-template void LandmarkDetection::performAsync<unsigned char>(LandmarkDetectionInput &input,
-                                                                                                                        shared_ptr<MetaInfo> metaInfo);
+template class LandmarkDetection<float>;
+template class LandmarkDetection<unsigned char>;
 
 }
 }
\ No newline at end of file
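diff --git a/mv_machine_learning/landmark_detection/src/pld_cpm.cpp b/mv_machine_learning/landmark_detection/src/pld_cpm.cpp
index d89fcec..c01d1fc 100644 (file)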
@@ -31,14 +31,15 @@ namespace mediavision
 {
 namespace machine_learning
 {
-PldCpm::PldCpm(LandmarkDetectionTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
-               : LandmarkDetection(task_type, config), _result()
+template<typename T>
+PldCpm<T>::PldCpm(LandmarkDetectionTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
+               : LandmarkDetection<T>(task_type, config), _result()
 {}
 
-PldCpm::~PldCpm()
+template<typename T> PldCpm<T>::~PldCpm()
 {}
 
-LandmarkDetectionResult &PldCpm::result()
+template<typename T> LandmarkDetectionResult &PldCpm<T>::result()
 {
        // Clear _result object because result() function can be called every time user wants
        // so make sure to clear existing result data before getting the data again.
@@ -46,7 +47,7 @@ LandmarkDetectionResult &PldCpm::result()
 
        vector<string> names;
 
-       LandmarkDetection::getOutputNames(names);
+       LandmarkDetection<T>::getOutputNames(names);
 
        auto scoreMetaInfo = _config->getOutputMetaMap().at(names[0]);
        auto decodingLandmark =
@@ -68,7 +69,7 @@ LandmarkDetectionResult &PldCpm::result()
 
        _result.number_of_landmarks = heatMapChannel;
 
-       LandmarkDetection::getOutputTensor(names[0], score_tensor);
+       LandmarkDetection<T>::getOutputTensor(names[0], score_tensor);
 
        auto ori_src_width = static_cast<double>(_preprocess.getImageWidth()[0]);
        auto ori_src_height = static_cast<double>(_preprocess.getImageHeight()[0]);
@@ -109,5 +110,8 @@ LandmarkDetectionResult &PldCpm::result()
        return _result;
 }
 
+template class PldCpm<unsigned char>;
+template class PldCpm<float>;
+
 }
 }
\ No newline at end of file
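diff --git a/mv_machine_learning/landmark_detection/src/pose_landmark_adapter.cpp b/mv_machine_learning/landmark_detection/src/pose_landmark_adapter.cpp
index bc43978..54cb21e 100644 (file)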
@@ -32,8 +32,7 @@ template<typename T, typename V> PoseLandmarkAdapter<T, V>::PoseLandmarkAdapter(
        _config = make_shared<MachineLearningConfig>();
        _config->parseConfigFile(_config_file_name);
 
-       LandmarkDetectionTaskType model_type = convertToTaskType(_config->getDefaultModelName());
-       create(model_type);
+       create(_config->getDefaultModelName());
 }
 
 template<typename T, typename V> PoseLandmarkAdapter<T, V>::~PoseLandmarkAdapter()
@@ -41,18 +40,35 @@ template<typename T, typename V> PoseLandmarkAdapter<T, V>::~PoseLandmarkAdapter
        _landmark_detection->preDestroy();
 }
 
-template<typename T, typename V> void PoseLandmarkAdapter<T, V>::create(LandmarkDetectionTaskType task_type)
-{
-       // If a concrete class object created already exists, reset the object
-       // so that other concrete class object can be created again according to a given task_type.
-       if (_landmark_detection) {
-               // If default task type is same as a given one then skip.
-               if (_landmark_detection->getTaskType() == task_type)
-                       return;
+template<typename T, typename V>
+template<typename U>
+void PoseLandmarkAdapter<T, V>::create(LandmarkDetectionTaskType task_type)
+{
+       switch (task_type) {
+       case LandmarkDetectionTaskType::PLD_CPM:
+               _landmark_detection = make_unique<PldCpm<U> >(task_type, _config);
+               break;
+       default:
+               throw InvalidOperation("Invalid landmark detection task type.");
        }
+}
 
-       if (task_type == LandmarkDetectionTaskType::PLD_CPM)
-               _landmark_detection = make_unique<PldCpm>(task_type, _config);
+template<typename T, typename V> void PoseLandmarkAdapter<T, V>::create(const string &model_name)
+{
+       LandmarkDetectionTaskType task_type = convertToTaskType(model_name);
+       _config->loadMetaFile(make_unique<LandmarkDetectionParser>(static_cast<int>(task_type)));
+       mv_inference_data_type_e dataType = _config->getInputMetaMap().begin()->second->dataType;
+
+       switch (dataType) {
+       case MV_INFERENCE_DATA_UINT8:
+               create<unsigned char>(task_type);
+               break;
+       case MV_INFERENCE_DATA_FLOAT32:
+               create<float>(task_type);
+               break;
+       default:
+               throw InvalidOperation("Invalid landmark detection data type.");
+       }
 }
 
 template<typename T, typename V>
@@ -75,9 +91,7 @@ void PoseLandmarkAdapter<T, V>::setModelInfo(const char *model_file, const char
 {
        try {
                _config->setUserModel(model_file, meta_file, label_file);
-
-               LandmarkDetectionTaskType model_type = convertToTaskType(model_name);
-               create(model_type);
+               create(model_name);
        } catch (const BaseException &e) {
                LOGW("A given model name is invalid so default task type will be used.");
        }