mv_machine_learning: drop input and output type dependency from object detection
authorInki Dae <inki.dae@samsung.com>
Thu, 30 Nov 2023 06:34:24 +0000 (15:34 +0900)
committerKwanghoon Son <k.son@samsung.com>
Wed, 6 Dec 2023 01:36:46 +0000 (10:36 +0900)
[Issue type] : code refactoring

Drop input and output type dependency from object detection task group
by making the input and output types specific to the object detection
task group to be inherited from the common types,
 and then by making adapter class of the object detection task group
to use the common type instead of specific one.

Change-Id: Ie1a397cd8fc05bd507497f04637d606815c5cccc
Signed-off-by: Inki Dae <inki.dae@samsung.com>
mv_machine_learning/object_detection/include/object_detection_type.h
mv_machine_learning/object_detection/src/face_detection_adapter.cpp
mv_machine_learning/object_detection/src/mv_face_detection.cpp
mv_machine_learning/object_detection/src/mv_object_detection.cpp
mv_machine_learning/object_detection/src/object_detection_adapter.cpp

index a897b9b..a653b25 100644 (file)
 
 #include <mv_common.h>
 #include <mv_inference_type.h>
-#include <mv_object_detection_type.h>
+#include "MachineLearningType.h"
 
 namespace mediavision
 {
 namespace machine_learning
 {
-struct ObjectDetectionInput {
-       mv_source_h inference_src {};
-       // TODO.
+struct ObjectDetectionInput : public InputBaseType {
+       ObjectDetectionInput(mv_source_h src = NULL) : InputBaseType(src)
+       {}
 };
 
 /**
  * @brief The object detection result structure.
  * @details Contains object detection result.
  */
-struct ObjectDetectionResult {
-       unsigned long frame_number {};
+struct ObjectDetectionResult : public OutputBaseType {
        unsigned int number_of_objects {};
        std::vector<unsigned int> indices;
        std::vector<std::string> names;
index 6e4be35..d50a5fb 100644 (file)
@@ -176,7 +176,7 @@ template<typename T, typename V> void FaceDetectionAdapter<T, V>::perform()
 
 template<typename T, typename V> void FaceDetectionAdapter<T, V>::performAsync(T &t)
 {
-       _object_detection->performAsync(t);
+       _object_detection->performAsync(static_cast<ObjectDetectionInput &>(t));
 }
 
 template<typename T, typename V> V &FaceDetectionAdapter<T, V>::getOutput()
@@ -189,6 +189,6 @@ template<typename T, typename V> V &FaceDetectionAdapter<T, V>::getOutputCache()
        return _object_detection->getOutputCache();
 }
 
-template class FaceDetectionAdapter<ObjectDetectionInput, ObjectDetectionResult>;
+template class FaceDetectionAdapter<InputBaseType, OutputBaseType>;
 }
 }
index 35b2b59..12128cb 100644 (file)
@@ -35,7 +35,7 @@ using namespace mediavision::common;
 using namespace mediavision::machine_learning;
 using namespace MediaVision::Common;
 using namespace mediavision::machine_learning::exception;
-using FaceDetectionTask = ITask<ObjectDetectionInput, ObjectDetectionResult>;
+using FaceDetectionTask = ITask<InputBaseType, OutputBaseType>;
 
 int mv_face_detection_create(mv_face_detection_h *handle)
 {
@@ -49,7 +49,7 @@ int mv_face_detection_create(mv_face_detection_h *handle)
 
        try {
                context = new Context();
-               task = new FaceDetectionAdapter<ObjectDetectionInput, ObjectDetectionResult>();
+               task = new FaceDetectionAdapter<InputBaseType, OutputBaseType>();
                context->__tasks.insert(make_pair("face_detection", task));
                *handle = static_cast<mv_face_detection_h>(context);
        } catch (const BaseException &e) {
@@ -293,7 +293,7 @@ int mv_face_detection_inference(mv_face_detection_h handle, mv_source_h source)
                auto context = static_cast<Context *>(handle);
                auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
 
-               ObjectDetectionInput input = { .inference_src = source };
+               ObjectDetectionInput input(source);
 
                task->setInput(input);
                task->perform();
@@ -324,7 +324,7 @@ int mv_face_detection_inference_async(mv_face_detection_h handle, mv_source_h so
                auto context = static_cast<Context *>(handle);
                auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
 
-               ObjectDetectionInput input = { source };
+               ObjectDetectionInput input(source);
 
                task->performAsync(input);
        } catch (const BaseException &e) {
@@ -357,7 +357,7 @@ int mv_face_detection_get_result(mv_face_detection_h handle, unsigned int *numbe
                auto context = static_cast<Context *>(handle);
                auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
 
-               ObjectDetectionResult &result = task->getOutput();
+               auto &result = static_cast<ObjectDetectionResult &>(task->getOutput());
                *number_of_objects = result.number_of_objects;
                *frame_number = result.frame_number;
                *confidences = result.confidences.data();
@@ -387,7 +387,7 @@ int mv_face_detection_get_label(mv_face_detection_h handle, const unsigned int i
                auto context = static_cast<Context *>(handle);
                auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
 
-               ObjectDetectionResult &result = task->getOutputCache();
+               auto &result = static_cast<ObjectDetectionResult &>(task->getOutputCache());
 
                if (result.number_of_objects <= index)
                        throw InvalidParameter("Invalid index range.");
index e90696a..7f1ac0c 100644 (file)
@@ -35,7 +35,7 @@ using namespace mediavision::common;
 using namespace mediavision::machine_learning;
 using namespace MediaVision::Common;
 using namespace mediavision::machine_learning::exception;
-using ObjectDetectionTask = ITask<ObjectDetectionInput, ObjectDetectionResult>;
+using ObjectDetectionTask = ITask<InputBaseType, OutputBaseType>;
 
 int mv_object_detection_create(mv_object_detection_h *handle)
 {
@@ -49,7 +49,7 @@ int mv_object_detection_create(mv_object_detection_h *handle)
 
        try {
                context = new Context();
-               task = new ObjectDetectionAdapter<ObjectDetectionInput, ObjectDetectionResult>();
+               task = new ObjectDetectionAdapter<InputBaseType, OutputBaseType>();
                context->__tasks.insert(make_pair("object_detection", task));
                *handle = static_cast<mv_object_detection_h>(context);
        } catch (const BaseException &e) {
@@ -291,7 +291,7 @@ int mv_object_detection_inference(mv_object_detection_h handle, mv_source_h sour
                auto context = static_cast<Context *>(handle);
                auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
 
-               ObjectDetectionInput input = { .inference_src = source };
+               ObjectDetectionInput input(source);
 
                task->setInput(input);
                task->perform();
@@ -317,7 +317,7 @@ int mv_object_detection_inference_async(mv_object_detection_h handle, mv_source_
                auto context = static_cast<Context *>(handle);
                auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
 
-               ObjectDetectionInput input = { source };
+               ObjectDetectionInput input(source);
 
                task->performAsync(input);
        } catch (const BaseException &e) {
@@ -350,7 +350,7 @@ int mv_object_detection_get_result(mv_object_detection_h handle, unsigned int *n
                auto context = static_cast<Context *>(handle);
                auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
 
-               ObjectDetectionResult &result = task->getOutput();
+               auto &result = static_cast<ObjectDetectionResult &>(task->getOutput());
                *number_of_objects = result.number_of_objects;
                *frame_number = result.frame_number;
                *confidences = result.confidences.data();
@@ -380,7 +380,7 @@ int mv_object_detection_get_label(mv_object_detection_h handle, const unsigned i
                auto context = static_cast<Context *>(handle);
                auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
 
-               ObjectDetectionResult &result = task->getOutputCache();
+               auto &result = static_cast<ObjectDetectionResult &>(task->getOutputCache());
 
                if (result.number_of_objects <= index)
                        throw InvalidParameter("Invalid index range.");
index 3d60ffa..cce24f4 100644 (file)
@@ -191,9 +191,9 @@ template<typename T, typename V> V &ObjectDetectionAdapter<T, V>::getOutputCache
 
 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::performAsync(T &t)
 {
-       _object_detection->performAsync(t);
+       _object_detection->performAsync(static_cast<ObjectDetectionInput &>(t));
 }
 
-template class ObjectDetectionAdapter<ObjectDetectionInput, ObjectDetectionResult>;
+template class ObjectDetectionAdapter<InputBaseType, OutputBaseType>;
 }
 }
\ No newline at end of file