mv_machine_learning: drop code duplication from object detection
author Inki Dae <inki.dae@samsung.com>
Fri, 1 Dec 2023 08:05:31 +0000 (17:05 +0900)
committer Kwanghoon Son <k.son@samsung.com>
Wed, 6 Dec 2023 01:36:46 +0000 (10:36 +0900)
[Issue type] : code cleanup

Drop code duplication from the object detection task group by making
this task group use the MachineLearningNative module instead of internal code
for context management. In addition, this patch fixes the user-given model issue
by passing the user-given model to each task group correctly.

Change-Id: I15993c841ae9ec5b9b1900089f88295d77b0629c
Signed-off-by: Inki Dae <inki.dae@samsung.com>
mv_machine_learning/common/include/MachineLearningNative.h
mv_machine_learning/common/src/MachineLearningNative.cpp
mv_machine_learning/landmark_detection/src/mv_facial_landmark.cpp
mv_machine_learning/landmark_detection/src/mv_pose_landmark.cpp
mv_machine_learning/object_detection/src/mv_face_detection.cpp
mv_machine_learning/object_detection/src/mv_object_detection.cpp
test/testsuites/machine_learning/object_detection/test_object_detection_async.cpp

index a6355f9..ced1b00 100644 (file)
@@ -40,7 +40,7 @@ void machine_learning_native_inference(void *handle, const std::string &task_nam
 void machine_learning_native_inference_async(void *handle, const std::string &task_name, InputBaseType &input);
 OutputBaseType &machine_learning_native_get_result(void *handle, const std::string &task_name);
 void machine_learning_native_set_model(void *handle, const std::string &task_name, const char *model_file,
-                                                                          const char *meta_file, const char *label_file);
+                                                                          const char *meta_file, const char *label_file, const char *model_name = "");
 void machine_learning_native_set_engine(void *handle, const std::string &task_name, const char *backend_type,
                                                                                const char *device_type);
 void machine_learning_native_get_engine_count(void *handle, const std::string &task_name, unsigned int *engine_count);
index 666eae2..24cb5b0 100644 (file)
@@ -96,11 +96,11 @@ OutputBaseType &machine_learning_native_get_result(void *handle, const string &t
 }
 
 void machine_learning_native_set_model(void *handle, const string &task_name, const char *model_file,
-                                                                          const char *meta_file, const char *label_file)
+                                                                          const char *meta_file, const char *label_file, const char *model_name)
 {
        auto task = get_task(handle, task_name);
 
-       task->setModelInfo(model_file, meta_file, label_file);
+       task->setModelInfo(model_file, meta_file, label_file, model_name);
 }
 
 void machine_learning_native_set_engine(void *handle, const string &task_name, const char *backend_type,
index 8175db3..cd46b2e 100644 (file)
@@ -85,7 +85,7 @@ int mv_facial_landmark_set_model(mv_facial_landmark_h handle, const char *model_
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               machine_learning_native_set_model(handle, TASK_NAME, model_file, meta_file, label_file);
+               machine_learning_native_set_model(handle, TASK_NAME, model_file, meta_file, label_file, model_name);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
index f440e06..5de2d98 100644 (file)
@@ -85,7 +85,7 @@ int mv_pose_landmark_set_model(mv_pose_landmark_h handle, const char *model_name
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               machine_learning_native_set_model(handle, TASK_NAME, model_file, meta_file, label_file);
+               machine_learning_native_set_model(handle, TASK_NAME, model_file, meta_file, label_file, model_name);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
index 12128cb..cac09ba 100644 (file)
@@ -19,6 +19,7 @@
 #include "mv_face_detection_internal.h"
 #include "face_detection_adapter.h"
 #include "machine_learning_exception.h"
+#include "MachineLearningNative.h"
 #include "object_detection_type.h"
 #include "context.h"
 
 #include <mutex>
 #include <iostream>
 
+#define TASK_NAME "face_detection"
+
 using namespace std;
 using namespace mediavision::inference;
 using namespace mediavision::common;
 using namespace mediavision::machine_learning;
 using namespace MediaVision::Common;
 using namespace mediavision::machine_learning::exception;
-using FaceDetectionTask = ITask<InputBaseType, OutputBaseType>;
+using FaceDetectionTask = FaceDetectionAdapter<InputBaseType, OutputBaseType>;
 
 int mv_face_detection_create(mv_face_detection_h *handle)
 {
@@ -44,17 +47,9 @@ int mv_face_detection_create(mv_face_detection_h *handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       Context *context = nullptr;
-       FaceDetectionTask *task = nullptr;
-
        try {
-               context = new Context();
-               task = new FaceDetectionAdapter<InputBaseType, OutputBaseType>();
-               context->__tasks.insert(make_pair("face_detection", task));
-               *handle = static_cast<mv_face_detection_h>(context);
+               machine_learning_native_create<InputBaseType, OutputBaseType>(TASK_NAME, new FaceDetectionTask(), handle);
        } catch (const BaseException &e) {
-               delete task;
-               delete context;
                return e.getError();
        }
 
@@ -70,12 +65,7 @@ int mv_face_detection_destroy(mv_face_detection_h handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       auto context = static_cast<Context *>(handle);
-
-       for (auto &m : context->__tasks)
-               delete static_cast<FaceDetectionTask *>(m.second);
-
-       delete context;
+       machine_learning_native_destory(handle, TASK_NAME);
 
        MEDIA_VISION_FUNCTION_LEAVE();
 
@@ -96,10 +86,7 @@ int mv_face_detection_set_model(mv_face_detection_h handle, const char *model_na
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
-
-               task->setModelInfo(model_file, meta_file, label_file, model_name);
+               machine_learning_native_set_model(handle, TASK_NAME, model_file, meta_file, label_file, model_name);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -121,10 +108,7 @@ int mv_face_detection_set_engine(mv_face_detection_h handle, const char *backend
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
-
-               task->setEngineInfo(backend_type, device_type);
+               machine_learning_native_set_engine(handle, TASK_NAME, backend_type, device_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -145,11 +129,7 @@ int mv_face_detection_get_engine_count(mv_face_detection_h handle, unsigned int
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
-
-               task->getNumberOfEngines(engine_count);
-               // TODO.
+               machine_learning_native_get_engine_count(handle, TASK_NAME, engine_count);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -170,11 +150,7 @@ int mv_face_detection_get_engine_type(mv_face_detection_h handle, const unsigned
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
-
-               task->getEngineType(engine_index, engine_type);
-               // TODO.
+               machine_learning_native_get_engine_type(handle, TASK_NAME, engine_index, engine_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -195,11 +171,7 @@ int mv_face_detection_get_device_count(mv_face_detection_h handle, const char *e
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
-
-               task->getNumberOfDevices(engine_type, device_count);
-               // TODO.
+               machine_learning_native_get_device_count(handle, TASK_NAME, engine_type, device_count);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -222,11 +194,7 @@ int mv_face_detection_get_device_type(mv_face_detection_h handle, const char *en
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
-
-               task->getDeviceType(engine_type, device_index, device_type);
-               // TODO.
+               machine_learning_native_get_device_type(handle, TASK_NAME, engine_type, device_index, device_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -245,10 +213,7 @@ int mv_face_detection_configure(mv_face_detection_h handle)
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
-
-               task->configure();
+               machine_learning_native_configure(handle, TASK_NAME);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -267,10 +232,7 @@ int mv_face_detection_prepare(mv_face_detection_h handle)
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
-
-               task->prepare();
+               machine_learning_native_prepare(handle, TASK_NAME);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -290,13 +252,9 @@ int mv_face_detection_inference(mv_face_detection_h handle, mv_source_h source)
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
-
                ObjectDetectionInput input(source);
 
-               task->setInput(input);
-               task->perform();
+               machine_learning_native_inference(handle, TASK_NAME, input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -315,18 +273,10 @@ int mv_face_detection_inference_async(mv_face_detection_h handle, mv_source_h so
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       if (!handle) {
-               LOGE("Handle is NULL.");
-               return MEDIA_VISION_ERROR_INVALID_PARAMETER;
-       }
-
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
-
                ObjectDetectionInput input(source);
 
-               task->performAsync(input);
+               machine_learning_native_inference_async(handle, TASK_NAME, input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -354,10 +304,7 @@ int mv_face_detection_get_result(mv_face_detection_h handle, unsigned int *numbe
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
-
-               auto &result = static_cast<ObjectDetectionResult &>(task->getOutput());
+               auto &result = static_cast<ObjectDetectionResult &>(machine_learning_native_get_result(handle, TASK_NAME));
                *number_of_objects = result.number_of_objects;
                *frame_number = result.frame_number;
                *confidences = result.confidences.data();
index 7f1ac0c..a86eff6 100644 (file)
@@ -19,6 +19,7 @@
 #include "mv_object_detection_internal.h"
 #include "object_detection_adapter.h"
 #include "machine_learning_exception.h"
+#include "MachineLearningNative.h"
 #include "object_detection_type.h"
 #include "context.h"
 
 #include <mutex>
 #include <iostream>
 
+#define TASK_NAME "object_detection"
+
 using namespace std;
 using namespace mediavision::inference;
 using namespace mediavision::common;
 using namespace mediavision::machine_learning;
 using namespace MediaVision::Common;
 using namespace mediavision::machine_learning::exception;
-using ObjectDetectionTask = ITask<InputBaseType, OutputBaseType>;
+using ObjectDetectionTask = ObjectDetectionAdapter<InputBaseType, OutputBaseType>;
 
 int mv_object_detection_create(mv_object_detection_h *handle)
 {
@@ -44,17 +47,9 @@ int mv_object_detection_create(mv_object_detection_h *handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       Context *context = nullptr;
-       ObjectDetectionTask *task = nullptr;
-
        try {
-               context = new Context();
-               task = new ObjectDetectionAdapter<InputBaseType, OutputBaseType>();
-               context->__tasks.insert(make_pair("object_detection", task));
-               *handle = static_cast<mv_object_detection_h>(context);
+               machine_learning_native_create<InputBaseType, OutputBaseType>(TASK_NAME, new ObjectDetectionTask(), handle);
        } catch (const BaseException &e) {
-               delete task;
-               delete context;
                return e.getError();
        }
 
@@ -70,12 +65,7 @@ int mv_object_detection_destroy(mv_object_detection_h handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       auto context = static_cast<Context *>(handle);
-
-       for (auto &m : context->__tasks)
-               delete static_cast<ObjectDetectionTask *>(m.second);
-
-       delete context;
+       machine_learning_native_destory(handle, TASK_NAME);
 
        MEDIA_VISION_FUNCTION_LEAVE();
 
@@ -92,10 +82,7 @@ int mv_object_detection_set_model(mv_object_detection_h handle, const char *mode
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
-
-               task->setModelInfo(model_file, meta_file, label_file, model_name);
+               machine_learning_native_set_model(handle, TASK_NAME, model_file, meta_file, label_file, model_name);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -117,10 +104,7 @@ int mv_object_detection_set_engine(mv_object_detection_h handle, const char *bac
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
-
-               task->setEngineInfo(backend_type, device_type);
+               machine_learning_native_set_engine(handle, TASK_NAME, backend_type, device_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -141,11 +125,7 @@ int mv_object_detection_get_engine_count(mv_object_detection_h handle, unsigned
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
-
-               task->getNumberOfEngines(engine_count);
-               // TODO.
+               machine_learning_native_get_engine_count(handle, TASK_NAME, engine_count);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -167,11 +147,7 @@ int mv_object_detection_get_engine_type(mv_object_detection_h handle, const unsi
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
-
-               task->getEngineType(engine_index, engine_type);
-               // TODO.
+               machine_learning_native_get_engine_type(handle, TASK_NAME, engine_index, engine_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -193,11 +169,7 @@ int mv_object_detection_get_device_count(mv_object_detection_h handle, const cha
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
-
-               task->getNumberOfDevices(engine_type, device_count);
-               // TODO.
+               machine_learning_native_get_device_count(handle, TASK_NAME, engine_type, device_count);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -220,11 +192,7 @@ int mv_object_detection_get_device_type(mv_object_detection_h handle, const char
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
-
-               task->getDeviceType(engine_type, device_index, device_type);
-               // TODO.
+               machine_learning_native_get_device_type(handle, TASK_NAME, engine_type, device_index, device_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -243,10 +211,7 @@ int mv_object_detection_configure(mv_object_detection_h handle)
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
-
-               task->configure();
+               machine_learning_native_configure(handle, TASK_NAME);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -265,10 +230,7 @@ int mv_object_detection_prepare(mv_object_detection_h handle)
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
-
-               task->prepare();
+               machine_learning_native_prepare(handle, TASK_NAME);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -288,13 +250,9 @@ int mv_object_detection_inference(mv_object_detection_h handle, mv_source_h sour
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
-
                ObjectDetectionInput input(source);
 
-               task->setInput(input);
-               task->perform();
+               machine_learning_native_inference(handle, TASK_NAME, input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -314,12 +272,9 @@ int mv_object_detection_inference_async(mv_object_detection_h handle, mv_source_
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
-
                ObjectDetectionInput input(source);
 
-               task->performAsync(input);
+               machine_learning_native_inference_async(handle, TASK_NAME, input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -347,10 +302,7 @@ int mv_object_detection_get_result(mv_object_detection_h handle, unsigned int *n
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
-
-               auto &result = static_cast<ObjectDetectionResult &>(task->getOutput());
+               auto &result = static_cast<ObjectDetectionResult &>(machine_learning_native_get_result(handle, TASK_NAME));
                *number_of_objects = result.number_of_objects;
                *frame_number = result.frame_number;
                *confidences = result.confidences.data();
@@ -377,10 +329,7 @@ int mv_object_detection_get_label(mv_object_detection_h handle, const unsigned i
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
-
-               auto &result = static_cast<ObjectDetectionResult &>(task->getOutputCache());
+               auto &result = static_cast<ObjectDetectionResult &>(machine_learning_native_get_result(handle, TASK_NAME));
 
                if (result.number_of_objects <= index)
                        throw InvalidParameter("Invalid index range.");
index 7651a8e..a6cdda9 100644 (file)
@@ -68,6 +68,9 @@ void object_detection_callback(void *user_data)
                        const char *label;
 
                        ret = mv_object_detection_get_label(handle, idx, &label);
+                       if (ret == MEDIA_VISION_ERROR_INVALID_OPERATION)
+                               break;
+
                        ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
                        cout << "index = " << idx << " label = " << label << endl;