mv_machine_learning: drop create member from itask
authorInki Dae <inki.dae@samsung.com>
Thu, 27 Jul 2023 06:58:47 +0000 (15:58 +0900)
committerKwanghoon Son <k.son@samsung.com>
Mon, 7 Aug 2023 04:25:06 +0000 (13:25 +0900)
[Issue type] : code cleanup

Drop create interface from itask interface class.

The create interface isn't common but specific to certain task group.
Therefore, drop the interface from itask interface class, and change
the create member function of each adapter class to private member if needed.

Change-Id: I65264ede3d5b627c04b7fd6d8cb7555677b0d665
Signed-off-by: Inki Dae <inki.dae@samsung.com>
17 files changed:
mv_machine_learning/common/include/itask.h
mv_machine_learning/face_recognition/include/face_recognition_adapter.h
mv_machine_learning/face_recognition/include/facenet_adapter.h
mv_machine_learning/face_recognition/src/face_recognition_adapter.cpp
mv_machine_learning/face_recognition/src/facenet_adapter.cpp
mv_machine_learning/image_classification/include/image_classification_adapter.h
mv_machine_learning/image_classification/src/image_classification_adapter.cpp
mv_machine_learning/landmark_detection/include/facial_landmark_adapter.h
mv_machine_learning/landmark_detection/include/pose_landmark_adapter.h
mv_machine_learning/landmark_detection/src/facial_landmark_adapter.cpp
mv_machine_learning/landmark_detection/src/pose_landmark_adapter.cpp
mv_machine_learning/object_detection/include/face_detection_adapter.h
mv_machine_learning/object_detection/include/object_detection_adapter.h
mv_machine_learning/object_detection/src/face_detection_adapter.cpp
mv_machine_learning/object_detection/src/object_detection_adapter.cpp
mv_machine_learning/object_detection_3d/include/object_detection_3d_adapter.h
mv_machine_learning/object_detection_3d/src/object_detection_3d_adapter.cpp

index 2c32c10..bfad5a9 100644 (file)
@@ -26,7 +26,6 @@ template<typename T, typename V> class ITask
 {
 public:
        virtual ~ITask() {};
-       virtual void create(int type) = 0;
        virtual void setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
                                                          const char *model_name = "") = 0;
        virtual void setEngineInfo(const char *engine_type, const char *device_type) = 0;
index 2de90a7..b6acf63 100644 (file)
@@ -43,7 +43,6 @@ public:
                return _config;
        }
 
-       void create(int type) override;
        void setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
                                          const char *model_name) override;
        void setEngineInfo(const char *engine_type, const char *device_type) override;
index 001eb84..17e1b26 100644 (file)
@@ -37,7 +37,6 @@ public:
        FacenetAdapter();
        ~FacenetAdapter();
 
-       void create(int type) override;
        void setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
                                          const char *model_name) override;
        void setEngineInfo(const char *engine_type, const char *device_type) override;
index 21d555c..a6126ef 100644 (file)
@@ -35,11 +35,6 @@ template<typename T, typename V> FaceRecognitionAdapter<T, V>::FaceRecognitionAd
 template<typename T, typename V> FaceRecognitionAdapter<T, V>::~FaceRecognitionAdapter()
 {}
 
-template<typename T, typename V> void FaceRecognitionAdapter<T, V>::create(int type)
-{
-       throw InvalidOperation("Not support yet.");
-}
-
 template<typename T, typename V>
 void FaceRecognitionAdapter<T, V>::setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
                                                                                                const char *model_name)
index a65c341..7f16ba5 100644 (file)
@@ -34,11 +34,6 @@ template<typename T, typename V> FacenetAdapter<T, V>::FacenetAdapter() : _sourc
 template<typename T, typename V> FacenetAdapter<T, V>::~FacenetAdapter()
 {}
 
-template<typename T, typename V> void FacenetAdapter<T, V>::create(int type)
-{
-       throw InvalidOperation("Not support yet.");
-}
-
 template<typename T, typename V>
 void FacenetAdapter<T, V>::setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
                                                                                const char *model_name)
index 75f88d9..45e493d 100644 (file)
@@ -37,7 +37,6 @@ public:
        ImageClassificationAdapter();
        ~ImageClassificationAdapter();
 
-       void create(int type) override;
        void setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
                                          const char *model_name) override;
        void setEngineInfo(const char *engine_type, const char *device_type) override;
index 29e2ae5..4c59dd3 100644 (file)
@@ -37,11 +37,6 @@ template<typename T, typename V> ImageClassificationAdapter<T, V>::ImageClassifi
 template<typename T, typename V> ImageClassificationAdapter<T, V>::~ImageClassificationAdapter()
 {}
 
-template<typename T, typename V> void ImageClassificationAdapter<T, V>::create(int type)
-{
-       throw InvalidOperation("Interface not supported.");
-}
-
 template<typename T, typename V>
 void ImageClassificationAdapter<T, V>::setModelInfo(const char *model_file, const char *meta_file,
                                                                                                        const char *label_file, const char *model_name)
index 49abb76..1ea521d 100644 (file)
@@ -37,12 +37,12 @@ private:
        std::string _meta_file;
        std::string _label_file;
 
+       void create(LandmarkDetectionTaskType task_type);
+
 public:
        FacialLandmarkAdapter();
        ~FacialLandmarkAdapter();
 
-       void create(int type) override;
-
        void setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
                                          const char *model_name) override;
        void setEngineInfo(const char *engine_type, const char *device_type) override;
index 99456df..f2b7f53 100644 (file)
@@ -37,12 +37,12 @@ private:
        std::string _meta_file;
        std::string _label_file;
 
+       void create(LandmarkDetectionTaskType task_type);
+
 public:
        PoseLandmarkAdapter();
        ~PoseLandmarkAdapter();
 
-       void create(int type) override;
-
        void setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
                                          const char *model_name) override;
        void setEngineInfo(const char *engine_type, const char *device_type) override;
index 6d2943a..9fb19e0 100644 (file)
@@ -37,10 +37,8 @@ template<typename T, typename V> FacialLandmarkAdapter<T, V>::FacialLandmarkAdap
 template<typename T, typename V> FacialLandmarkAdapter<T, V>::~FacialLandmarkAdapter()
 {}
 
-template<typename T, typename V> void FacialLandmarkAdapter<T, V>::create(int type)
+template<typename T, typename V> void FacialLandmarkAdapter<T, V>::create(LandmarkDetectionTaskType task_type)
 {
-       LandmarkDetectionTaskType task_type = static_cast<LandmarkDetectionTaskType>(type);
-
        // If default task type is same as a given one then skip.
        if (_landmark_detection->getTaskType() == task_type)
                return;
@@ -61,15 +59,15 @@ void FacialLandmarkAdapter<T, V>::setModelInfo(const char *model_file, const cha
        if (!model_name_str.empty()) {
                transform(model_name_str.begin(), model_name_str.end(), model_name_str.begin(), ::toupper);
 
-               int model_type = 0;
+               LandmarkDetectionTaskType task_type = LandmarkDetectionTaskType::LANDMARK_DETECTION_TASK_NONE;
 
                if (model_name_str == string("FLD_TWEAK_CNN"))
-                       model_type = static_cast<int>(LandmarkDetectionTaskType::FLD_TWEAK_CNN);
+                       task_type = LandmarkDetectionTaskType::FLD_TWEAK_CNN;
                // TODO.
                else
                        throw InvalidParameter("Invalid landmark detection model name.");
 
-               create(static_cast<int>(model_type));
+               create(task_type);
        }
 
        _model_file = string(model_file);
index 3ba82d1..5fb6c31 100644 (file)
@@ -37,10 +37,8 @@ template<typename T, typename V> PoseLandmarkAdapter<T, V>::PoseLandmarkAdapter(
 template<typename T, typename V> PoseLandmarkAdapter<T, V>::~PoseLandmarkAdapter()
 {}
 
-template<typename T, typename V> void PoseLandmarkAdapter<T, V>::create(int type)
+template<typename T, typename V> void PoseLandmarkAdapter<T, V>::create(LandmarkDetectionTaskType task_type)
 {
-       LandmarkDetectionTaskType task_type = static_cast<LandmarkDetectionTaskType>(type);
-
        // If default task type is same as a given one then skip.
        if (_landmark_detection->getTaskType() == task_type)
                return;
@@ -61,15 +59,15 @@ void PoseLandmarkAdapter<T, V>::setModelInfo(const char *model_file, const char
        if (!model_name_str.empty()) {
                transform(model_name_str.begin(), model_name_str.end(), model_name_str.begin(), ::toupper);
 
-               int model_type = 0;
+               LandmarkDetectionTaskType task_type = LandmarkDetectionTaskType::LANDMARK_DETECTION_TASK_NONE;
 
                if (model_name_str == string("PLD_CPM"))
-                       model_type = static_cast<int>(LandmarkDetectionTaskType::PLD_CPM);
+                       task_type = LandmarkDetectionTaskType::PLD_CPM;
                // TODO.
                else
                        throw InvalidParameter("Invalid landmark detection model name.");
 
-               create(static_cast<int>(model_type));
+               create(task_type);
        }
 
        _model_file = string(model_file);
index d71bbe3..50e32d1 100644 (file)
@@ -37,12 +37,12 @@ private:
        std::string _meta_file;
        std::string _label_file;
 
+       void create(ObjectDetectionTaskType task_type);
+
 public:
        FaceDetectionAdapter();
        ~FaceDetectionAdapter();
 
-       void create(int type) override;
-
        void setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
                                          const char *model_name) override;
        void setEngineInfo(const char *engine_type, const char *device_type) override;
index 5975c68..7aa6ff4 100644 (file)
@@ -38,12 +38,12 @@ private:
        std::string _meta_file;
        std::string _label_file;
 
+       void create(ObjectDetectionTaskType task_type);
+
 public:
        ObjectDetectionAdapter();
        ~ObjectDetectionAdapter();
 
-       void create(int type) override;
-
        void setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
                                          const char *model_name) override;
        void setEngineInfo(const char *engine_type, const char *device_type) override;
index e718305..239a6d6 100644 (file)
@@ -39,10 +39,8 @@ template<typename T, typename V> FaceDetectionAdapter<T, V>::~FaceDetectionAdapt
        _object_detection->preDestroy();
 }
 
-template<typename T, typename V> void FaceDetectionAdapter<T, V>::create(int type)
+template<typename T, typename V> void FaceDetectionAdapter<T, V>::create(ObjectDetectionTaskType task_type)
 {
-       ObjectDetectionTaskType task_type = static_cast<ObjectDetectionTaskType>(type);
-
        // If default task type is same as a given one then skip.
        if (_object_detection->getTaskType() == task_type)
                return;
@@ -63,15 +61,15 @@ void FaceDetectionAdapter<T, V>::setModelInfo(const char *model_file, const char
        if (!model_name_str.empty()) {
                transform(model_name_str.begin(), model_name_str.end(), model_name_str.begin(), ::toupper);
 
-               int model_type = 0;
+               ObjectDetectionTaskType task_type = ObjectDetectionTaskType::OBJECT_DETECTION_TASK_NONE;
 
                if (model_name_str == string("FD_MOBILENET_V1_SSD"))
-                       model_type = static_cast<int>(ObjectDetectionTaskType::FD_MOBILENET_V1_SSD);
+                       task_type = ObjectDetectionTaskType::FD_MOBILENET_V1_SSD;
                // TODO.
                else
                        throw InvalidParameter("Invalid face detection model name.");
 
-               create(static_cast<int>(model_type));
+               create(task_type);
        }
 
        _model_file = string(model_file);
index 6165c4a..57993b3 100644 (file)
@@ -40,10 +40,8 @@ template<typename T, typename V> ObjectDetectionAdapter<T, V>::~ObjectDetectionA
        _object_detection->preDestroy();
 }
 
-template<typename T, typename V> void ObjectDetectionAdapter<T, V>::create(int type)
+template<typename T, typename V> void ObjectDetectionAdapter<T, V>::create(ObjectDetectionTaskType task_type)
 {
-       ObjectDetectionTaskType task_type = static_cast<ObjectDetectionTaskType>(type);
-
        // If default task type is same as a given one then skip.
        if (_object_detection->getTaskType() == task_type)
                return;
@@ -68,30 +66,30 @@ void ObjectDetectionAdapter<T, V>::setModelInfo(const char *model_file, const ch
        if (!model_name_str.empty()) {
                transform(model_name_str.begin(), model_name_str.end(), model_name_str.begin(), ::toupper);
 
-               int model_type = 0;
+               ObjectDetectionTaskType task_type = ObjectDetectionTaskType::OBJECT_DETECTION_TASK_NONE;
 
                if (model_name_str == "OD_PLUGIN" || model_name_str == "FD_PLUGIN") {
                        if (model_name_str == "OD_PLUGIN")
-                               model_type = static_cast<int>(ObjectDetectionTaskType::OD_PLUGIN);
+                               task_type = ObjectDetectionTaskType::OD_PLUGIN;
                        if (model_name_str == "FD_PLUGIN")
-                               model_type = static_cast<int>(ObjectDetectionTaskType::FD_PLUGIN);
+                               task_type = ObjectDetectionTaskType::FD_PLUGIN;
 
                        // In case of using plugin module, model information will be managed by the plugin module.
                        // Therefore, create plugin instance now.
-                       create(model_type);
+                       create(task_type);
                        return;
                }
 
                if (model_name_str == string("MOBILENET_V1_SSD")) {
-                       model_type = static_cast<int>(ObjectDetectionTaskType::MOBILENET_V1_SSD);
+                       task_type = ObjectDetectionTaskType::MOBILENET_V1_SSD;
                } else if (model_name_str == string("MOBILENET_V2_SSD")) {
-                       model_type = static_cast<int>(ObjectDetectionTaskType::MOBILENET_V2_SSD);
+                       task_type = ObjectDetectionTaskType::MOBILENET_V2_SSD;
                        // TODO.
                } else {
                        throw InvalidParameter("Invalid object detection model name.");
                }
 
-               create(model_type);
+               create(task_type);
        }
 
        _model_file = string(model_file);
index 248f27c..542dc85 100644 (file)
@@ -37,11 +37,12 @@ private:
        std::string _meta_file;
        std::string _label_file;
 
+       void create(ObjectDetection3dTaskType task_type);
+
 public:
        ObjectDetection3dAdapter();
        ~ObjectDetection3dAdapter();
 
-       void create(int type) override;
        void setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
                                          const char *model_name) override;
        void setEngineInfo(const char *engine_type, const char *device_type) override;
index b782a39..b72274f 100644 (file)
@@ -34,7 +34,7 @@ template<typename T, typename V> ObjectDetection3dAdapter<T, V>::ObjectDetection
 template<typename T, typename V> ObjectDetection3dAdapter<T, V>::~ObjectDetection3dAdapter()
 {}
 
-template<typename T, typename V> void ObjectDetection3dAdapter<T, V>::create(int type)
+template<typename T, typename V> void ObjectDetection3dAdapter<T, V>::create(ObjectDetection3dTaskType task_type)
 {}
 
 template<typename T, typename V>
@@ -46,14 +46,14 @@ void ObjectDetection3dAdapter<T, V>::setModelInfo(const char *model_file, const
        if (!model_name_str.empty()) {
                transform(model_name_str.begin(), model_name_str.end(), model_name_str.begin(), ::toupper);
 
-               int model_type = 0;
+               ObjectDetection3dTaskType task_type = ObjectDetection3dTaskType::OBJECT_DETECTION_3D_TASK_NONE;
 
                if (model_name_str == string("OBJECTRON"))
-                       model_type = static_cast<int>(ObjectDetection3dTaskType::OBJECTRON);
+                       task_type = ObjectDetection3dTaskType::OBJECTRON;
                else
                        throw InvalidParameter("Invalid object detection 3d model name.");
 
-               create(static_cast<int>(model_type));
+               create(task_type);
        }
 
        _model_file = string(model_file);