mv_machine_learning: add multiple tasks support to MachineLearningNative module
authorInki Dae <inki.dae@samsung.com>
Fri, 8 Dec 2023 07:18:03 +0000 (16:18 +0900)
committerKwanghoon Son <k.son@samsung.com>
Wed, 27 Dec 2023 03:22:49 +0000 (12:22 +0900)
[Issue type] : new feature

Add multiple tasks support to MachineLearningNative module. In case of
face recognition task group, two tasks - facenet and face_recognition -
are needed. Therefore, this patch adds a new API,
machine_learning_native_add() which adds a given task object to
the given context, to MachineLearningNative module, changes return type of
machine_learning_native_create() to void * type for API consistency, and
updates other task groups to call this API.

In addition, this patch drops a redundant parameter, task_name, from
machine_learning_native_destroy().

Change-Id: I9c3ca2645dc876a3c7eb31bb9ce0aca56d376ad9
Signed-off-by: Inki Dae <inki.dae@samsung.com>
mv_machine_learning/common/include/MachineLearningNative.h
mv_machine_learning/common/src/MachineLearningNative.cpp
mv_machine_learning/face_recognition/src/mv_face_recognition.cpp
mv_machine_learning/image_classification/src/mv_image_classification.cpp
mv_machine_learning/landmark_detection/src/mv_facial_landmark.cpp
mv_machine_learning/landmark_detection/src/mv_pose_landmark.cpp
mv_machine_learning/object_detection/src/mv_face_detection.cpp
mv_machine_learning/object_detection/src/mv_object_detection.cpp

index 67d21cf..9086a83 100644 (file)
@@ -30,8 +30,9 @@ namespace mediavision
 {
 namespace machine_learning
 {
-void machine_learning_native_create(const std::string &task_name, mediavision::common::ITask *task, void **handle);
-void machine_learning_native_destory(void *handle, const std::string &task_name);
+void *machine_learning_native_create();
+void machine_learning_native_add(void *handle, const std::string &task_name, mediavision::common::ITask *task);
+void machine_learning_native_destory(void *handle);
 void machine_learning_native_configure(void *handle, const std::string &task_name);
 void machine_learning_native_prepare(void *handle, const std::string &task_name);
 void machine_learning_native_inference(void *handle, const std::string &task_name, InputBaseType &input);
index f1c714c..d361edb 100644 (file)
@@ -35,15 +35,19 @@ inline ITask *get_task(void *handle, const std::string &name)
        return context->__tasks.at(name);
 }
 
-void machine_learning_native_create(const string &task_name, ITask *task, void **handle)
+void *machine_learning_native_create()
 {
-       Context *context = new Context();
+       return static_cast<void *>(new Context());
+}
+
+void machine_learning_native_add(void *handle, const string &task_name, ITask *task)
+{
+       auto context = static_cast<Context *>(handle);
 
        context->__tasks.insert(make_pair(task_name, task));
-       *handle = static_cast<void *>(context);
 }
 
-void machine_learning_native_destory(void *handle, const string &task_name)
+void machine_learning_native_destory(void *handle)
 {
        auto context = static_cast<Context *>(handle);
 
index 7a2a870..ee257dd 100644 (file)
@@ -25,6 +25,7 @@
 #include "machine_learning_exception.h"
 #include "face_recognition_adapter.h"
 #include "facenet_adapter.h"
+#include "MachineLearningNative.h"
 #include "mv_face_recognition.h"
 #include "mv_face_recognition_internal.h"
 
@@ -50,27 +51,19 @@ int mv_face_recognition_create(mv_face_recognition_h *out_handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       Context *context = nullptr;
-       ITask *face_recognition_task = nullptr;
-       ITask *facenet_task = nullptr;
+       mv_face_recognition_h ctx = nullptr;
 
        try {
-               context = new Context();
-               face_recognition_task = new FaceRecognitionAdapter();
-               facenet_task = new FacenetAdapter();
-               context->__tasks.insert(make_pair("face_recognition", face_recognition_task));
-               context->__tasks.insert(make_pair("facenet", facenet_task));
-
-               *out_handle = static_cast<mv_face_recognition_h>(context);
-
-               LOGD("face recognition handle [%p] has been created", *out_handle);
+               ctx = machine_learning_native_create();
+               machine_learning_native_add(ctx, "face_recognition", new FaceRecognitionAdapter());
+               machine_learning_native_add(ctx, "facenet", new FacenetAdapter());
        } catch (const BaseException &e) {
-               delete face_recognition_task;
-               delete facenet_task;
-               delete context;
                return e.getError();
        }
 
+       *out_handle = ctx;
+       LOGD("face recognition handle [%p] has been created", *out_handle);
+
        MEDIA_VISION_FUNCTION_LEAVE();
 
        return MEDIA_VISION_ERROR_NONE;
@@ -86,19 +79,12 @@ int mv_face_recognition_destroy(mv_face_recognition_h handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       Context *context = static_cast<Context *>(handle);
-       map<string, ITask *>::iterator iter;
-
-       for (iter = context->__tasks.begin(); iter != context->__tasks.end(); ++iter) {
-               if (iter->first.compare("face_recognition") == 0)
-                       delete iter->second;
-
-               if (iter->first.compare("facenet") == 0)
-                       delete iter->second;
+       try {
+               machine_learning_native_destory(handle);
+       } catch (const BaseException &e) {
+               return e.getError();
        }
 
-       delete context;
-
        LOGD("Face recognition handle has been destroyed");
 
        MEDIA_VISION_FUNCTION_LEAVE();
@@ -117,14 +103,11 @@ int mv_face_recognition_prepare(mv_face_recognition_h handle)
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               Context *context = static_cast<Context *>(handle);
-               auto face_recognition_task = context->__tasks["face_recognition"];
-               auto facenet_task = context->__tasks["facenet"];
-
-               face_recognition_task->configure();
-               facenet_task->configure();
-               face_recognition_task->prepare();
-               facenet_task->prepare();
+               machine_learning_native_configure(handle, "face_recognition");
+               machine_learning_native_configure(handle, "facenet");
+
+               machine_learning_native_prepare(handle, "face_recognition");
+               machine_learning_native_prepare(handle, "facenet");
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -148,22 +131,17 @@ int mv_face_recognition_register(mv_face_recognition_h handle, mv_source_h sourc
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               Context *context = static_cast<Context *>(handle);
-               auto face_recognition_task = context->__tasks["face_recognition"];
-               auto facenet_task = context->__tasks["facenet"];
                FacenetInput facenet_input(source);
 
-               facenet_task->perform(facenet_input);
+               machine_learning_native_inference(handle, "facenet", facenet_input);
 
-               auto &facenet_output = static_cast<FacenetOutput &>(facenet_task->getOutput());
+               auto &facenet_output = static_cast<FacenetOutput &>(machine_learning_native_get_result(handle, "facenet"));
                FaceRecognitionInput face_recognition_input;
 
                face_recognition_input.mode = RequestMode::REGISTER;
-
                face_recognition_input.inputs.push_back(facenet_output.outputs[0]);
                face_recognition_input.labels.push_back(label);
-
-               face_recognition_task->perform(face_recognition_input);
+               machine_learning_native_inference(handle, "face_recognition", face_recognition_input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -186,15 +164,12 @@ int mv_face_recognition_unregister(mv_face_recognition_h handle, const char *lab
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               Context *context = static_cast<Context *>(handle);
-               auto face_recognition_task = context->__tasks["face_recognition"];
                FaceRecognitionInput input;
 
                input.mode = RequestMode::DELETE;
-
                input.labels.clear();
                input.labels.push_back(label);
-               face_recognition_task->perform(input);
+               machine_learning_native_inference(handle, "face_recognition", input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -217,20 +192,16 @@ int mv_face_recognition_inference(mv_face_recognition_h handle, mv_source_h sour
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               Context *context = static_cast<Context *>(handle);
-               auto face_recognition_task = context->__tasks["face_recognition"];
-               auto facenet_task = context->__tasks["facenet"];
                FacenetInput facenet_input(source);
 
-               facenet_task->perform(facenet_input);
+               machine_learning_native_inference(handle, "facenet", facenet_input);
 
-               auto &facenet_output = static_cast<FacenetOutput &>(facenet_task->getOutput());
+               auto &facenet_output = static_cast<FacenetOutput &>(machine_learning_native_get_result(handle, "facenet"));
                FaceRecognitionInput face_recognition_input;
 
                face_recognition_input.mode = RequestMode::INFERENCE;
-
                face_recognition_input.inputs = facenet_output.outputs;
-               face_recognition_task->perform(face_recognition_input);
+               machine_learning_native_inference(handle, "face_recognition", face_recognition_input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -253,9 +224,8 @@ int mv_face_recognition_get_label(mv_face_recognition_h handle, const char **out
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               Context *context = static_cast<Context *>(handle);
-               auto face_recognition_task = context->__tasks["face_recognition"];
-               auto &result = static_cast<FaceRecognitionResult &>(face_recognition_task->getOutput());
+               auto &result =
+                               static_cast<FaceRecognitionResult &>(machine_learning_native_get_result(handle, "face_recognition"));
 
                *out_label = result.label.c_str();
        } catch (const BaseException &e) {
@@ -282,9 +252,8 @@ int mv_face_recognition_get_confidence(mv_face_recognition_h handle, const float
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto face_recognition_task = context->__tasks["face_recognition"];
-               auto &result = static_cast<FaceRecognitionResult &>(face_recognition_task->getOutput());
+               auto &result =
+                               static_cast<FaceRecognitionResult &>(machine_learning_native_get_result(handle, "face_recognition"));
 
                *confidences = result.raw_data.data();
                *num_of_confidences = result.raw_data.size();
@@ -310,9 +279,8 @@ int mv_face_recognition_get_label_with_index(mv_face_recognition_h handle, const
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto face_recognition_task = context->__tasks["face_recognition"];
-               auto &result = static_cast<FaceRecognitionResult &>(face_recognition_task->getOutput());
+               auto &result =
+                               static_cast<FaceRecognitionResult &>(machine_learning_native_get_result(handle, "face_recognition"));
 
                if (static_cast<size_t>(index) >= result.labels.size())
                        throw InvalidParameter("A given index is out of boundary.");
index 777719e..208624e 100644 (file)
@@ -48,13 +48,17 @@ int mv_image_classification_create(mv_image_classification_h *out_handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
+       mv_image_classification_h ctx = nullptr;
+
        try {
-               machine_learning_native_create(TASK_NAME, new ImageClassificationAdapter(), out_handle);
+               ctx = machine_learning_native_create();
+               machine_learning_native_add(ctx, TASK_NAME, new ImageClassificationAdapter());
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
        }
 
+       *out_handle = ctx;
        LOGD("image classification handle [%p] has been created", *out_handle);
 
        MEDIA_VISION_FUNCTION_LEAVE();
@@ -69,7 +73,7 @@ int mv_image_classification_destroy(mv_image_classification_h handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       machine_learning_native_destory(handle, TASK_NAME);
+       machine_learning_native_destory(handle);
 
        MEDIA_VISION_FUNCTION_LEAVE();
 
index cca0757..40d2ae4 100644 (file)
@@ -49,13 +49,17 @@ int mv_facial_landmark_create(mv_facial_landmark_h *handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
+       mv_facial_landmark_h ctx = nullptr;
+
        try {
-               machine_learning_native_create(TASK_NAME, new FacialLandmarkAdapter(), handle);
+               ctx = machine_learning_native_create();
+               machine_learning_native_add(ctx, TASK_NAME, new FacialLandmarkAdapter());
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
        }
 
+       *handle = ctx;
        MEDIA_VISION_FUNCTION_LEAVE();
 
        return MEDIA_VISION_ERROR_NONE;
@@ -68,7 +72,7 @@ int mv_facial_landmark_destroy(mv_facial_landmark_h handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       machine_learning_native_destory(handle, TASK_NAME);
+       machine_learning_native_destory(handle);
 
        MEDIA_VISION_FUNCTION_LEAVE();
 
index c899e03..09dc5ed 100644 (file)
@@ -49,13 +49,17 @@ int mv_pose_landmark_create(mv_pose_landmark_h *handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
+       mv_pose_landmark_h ctx = nullptr;
+
        try {
-               machine_learning_native_create(TASK_NAME, new PoseLandmarkAdapter(), handle);
+               ctx = machine_learning_native_create();
+               machine_learning_native_add(ctx, TASK_NAME, new PoseLandmarkAdapter());
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
        }
 
+       *handle = ctx;
        MEDIA_VISION_FUNCTION_LEAVE();
 
        return MEDIA_VISION_ERROR_NONE;
@@ -68,7 +72,7 @@ int mv_pose_landmark_destroy(mv_pose_landmark_h handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       machine_learning_native_destory(handle, TASK_NAME);
+       machine_learning_native_destory(handle);
 
        MEDIA_VISION_FUNCTION_LEAVE();
 
index 821ba2f..f5470d1 100644 (file)
@@ -51,12 +51,16 @@ int mv_face_detection_create(mv_face_detection_h *handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
+       mv_face_detection_h ctx = nullptr;
+
        try {
-               machine_learning_native_create(TASK_NAME, new FaceDetectionAdapter(), handle);
+               ctx = machine_learning_native_create();
+               machine_learning_native_add(ctx, TASK_NAME, new FaceDetectionAdapter());
        } catch (const BaseException &e) {
                return e.getError();
        }
 
+       *handle = ctx;
        MEDIA_VISION_FUNCTION_LEAVE();
 
        return MEDIA_VISION_ERROR_NONE;
@@ -69,7 +73,7 @@ int mv_face_detection_destroy(mv_face_detection_h handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       machine_learning_native_destory(handle, TASK_NAME);
+       machine_learning_native_destory(handle);
 
        MEDIA_VISION_FUNCTION_LEAVE();
 
index 35050af..4fd9b9d 100644 (file)
@@ -51,12 +51,16 @@ int mv_object_detection_create(mv_object_detection_h *handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
+       mv_object_detection_h ctx = nullptr;
+
        try {
-               machine_learning_native_create(TASK_NAME, new ObjectDetectionAdapter(), handle);
+               ctx = machine_learning_native_create();
+               machine_learning_native_add(ctx, TASK_NAME, new ObjectDetectionAdapter());
        } catch (const BaseException &e) {
                return e.getError();
        }
 
+       *handle = ctx;
        MEDIA_VISION_FUNCTION_LEAVE();
 
        return MEDIA_VISION_ERROR_NONE;
@@ -69,7 +73,7 @@ int mv_object_detection_destroy(mv_object_detection_h handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       machine_learning_native_destory(handle, TASK_NAME);
+       machine_learning_native_destory(handle);
 
        MEDIA_VISION_FUNCTION_LEAVE();