mv_machine_learning: drop code duplication from landmark detection
author Inki Dae <inki.dae@samsung.com>
Wed, 29 Nov 2023 07:46:28 +0000 (16:46 +0900)
committer Kwanghoon Son <k.son@samsung.com>
Wed, 6 Dec 2023 01:36:46 +0000 (10:36 +0900)
[Issue type] : code cleanup

Drop code duplication from the landmark detection task group by making
this task group use the MachineLearningNative module instead of internal
code for context management.
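
For context, the shared pattern this change moves to can be sketched roughly
as below: a Context owns the ITask objects keyed by task name, and thin helper
functions do the handle cast, lookup, and dispatch that every
mv_facial_landmark_*() / mv_pose_landmark_*() wrapper previously open-coded.
This is an illustrative sketch with simplified stand-in types and names
(native_create() and friends stand for the machine_learning_native_*() calls
used in the diff), not the actual MachineLearningNative.h.

// Illustrative sketch only - simplified stand-ins, not the real module.
#include <map>
#include <memory>
#include <string>

struct InputBaseType {};   // stand-in for the real input base type
struct OutputBaseType {};  // stand-in for the real output base type

struct ITask {             // minimal task interface as used by the wrappers
        virtual ~ITask() = default;
        virtual void setInput(InputBaseType &input) = 0;
        virtual void perform() = 0;
        virtual OutputBaseType &getOutput() = 0;
};

// One Context per public handle; tasks are keyed by name ("facial_landmark", ...).
struct Context {
        std::map<std::string, std::unique_ptr<ITask>> tasks;
};

void native_create(const std::string &name, ITask *task, void **handle)
{
        auto context = std::make_unique<Context>();
        context->tasks[name] = std::unique_ptr<ITask>(task);
        *handle = context.release();
}

void native_destroy(void *handle)
{
        delete static_cast<Context *>(handle);  // unique_ptr frees the tasks
}

void native_inference(void *handle, const std::string &name, InputBaseType &input)
{
        auto &task = static_cast<Context *>(handle)->tasks.at(name);
        task->setInput(input);
        task->perform();
}

OutputBaseType &native_get_result(void *handle, const std::string &name)
{
        return static_cast<Context *>(handle)->tasks.at(name)->getOutput();
}

With the ownership and lookup kept in one place, each public wrapper shrinks
to a try/catch around a single helper call, as the hunks below show.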

Change-Id: I990edccfe92cde801e665f9fc43f9f566f33c124
Signed-off-by: Inki Dae <inki.dae@samsung.com>
mv_machine_learning/landmark_detection/src/mv_facial_landmark.cpp
mv_machine_learning/landmark_detection/src/mv_pose_landmark.cpp

diff --git a/mv_machine_learning/landmark_detection/src/mv_facial_landmark.cpp b/mv_machine_learning/landmark_detection/src/mv_facial_landmark.cpp
index 34bf817..8175db3 100644
--- a/mv_machine_learning/landmark_detection/src/mv_facial_landmark.cpp
+++ b/mv_machine_learning/landmark_detection/src/mv_facial_landmark.cpp
@@ -19,6 +19,7 @@
 #include "mv_facial_landmark_internal.h"
 #include "facial_landmark_adapter.h"
 #include "machine_learning_exception.h"
+#include "MachineLearningNative.h"
 #include "landmark_detection_type.h"
 #include "context.h"
 
@@ -26,8 +27,8 @@
 #include <unistd.h>
 #include <string>
 #include <algorithm>
-#include <mutex>
-#include <iostream>
+
+#define TASK_NAME "facial_landmark"
 
 using namespace std;
 using namespace mediavision::inference;
@@ -35,7 +36,7 @@ using namespace mediavision::common;
 using namespace mediavision::machine_learning;
 using namespace MediaVision::Common;
 using namespace mediavision::machine_learning::exception;
-using LandmarkDetectionTask = ITask<InputBaseType, OutputBaseType>;
+using LandmarkDetectionTask = FacialLandmarkAdapter<InputBaseType, OutputBaseType>;
 
 int mv_facial_landmark_create(mv_facial_landmark_h *handle)
 {
@@ -44,17 +45,10 @@ int mv_facial_landmark_create(mv_facial_landmark_h *handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       Context *context = nullptr;
-       LandmarkDetectionTask *task = nullptr;
-
        try {
-               context = new Context();
-               task = new FacialLandmarkAdapter<InputBaseType, OutputBaseType>();
-               context->__tasks.insert(make_pair("facial_landmark", task));
-               *handle = static_cast<mv_facial_landmark_h>(context);
+               machine_learning_native_create<InputBaseType, OutputBaseType>(TASK_NAME, new LandmarkDetectionTask(), handle);
        } catch (const BaseException &e) {
-               delete task;
-               delete context;
+               LOGE("%s", e.what());
                return e.getError();
        }
 
@@ -70,12 +64,7 @@ int mv_facial_landmark_destroy(mv_facial_landmark_h handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       auto context = static_cast<Context *>(handle);
-
-       for (auto &m : context->__tasks)
-               delete static_cast<LandmarkDetectionTask *>(m.second);
-
-       delete context;
+       machine_learning_native_destory(handle, TASK_NAME);
 
        MEDIA_VISION_FUNCTION_LEAVE();
 
@@ -96,10 +85,7 @@ int mv_facial_landmark_set_model(mv_facial_landmark_h handle, const char *model_
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("facial_landmark"));
-
-               task->setModelInfo(model_file, meta_file, label_file, model_name);
+               machine_learning_native_set_model(handle, TASK_NAME, model_file, meta_file, label_file);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -121,10 +107,7 @@ int mv_facial_landmark_set_engine(mv_facial_landmark_h handle, const char *backe
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("facial_landmark"));
-
-               task->setEngineInfo(backend_type, device_type);
+               machine_learning_native_set_engine(handle, TASK_NAME, backend_type, device_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -145,11 +128,7 @@ int mv_facial_landmark_get_engine_count(mv_facial_landmark_h handle, unsigned in
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("facial_landmark"));
-
-               task->getNumberOfEngines(engine_count);
-               // TODO.
+               machine_learning_native_get_engine_count(handle, TASK_NAME, engine_count);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -170,11 +149,7 @@ int mv_facial_landmark_get_engine_type(mv_facial_landmark_h handle, const unsign
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("facial_landmark"));
-
-               task->getEngineType(engine_index, engine_type);
-               // TODO.
+               machine_learning_native_get_engine_type(handle, TASK_NAME, engine_index, engine_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -196,11 +171,7 @@ int mv_facial_landmark_get_device_count(mv_facial_landmark_h handle, const char
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("facial_landmark"));
-
-               task->getNumberOfDevices(engine_type, device_count);
-               // TODO.
+               machine_learning_native_get_device_count(handle, TASK_NAME, engine_type, device_count);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -223,11 +194,7 @@ int mv_facial_landmark_get_device_type(mv_facial_landmark_h handle, const char *
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("facial_landmark"));
-
-               task->getDeviceType(engine_type, device_index, device_type);
-               // TODO.
+               machine_learning_native_get_device_type(handle, TASK_NAME, engine_type, device_index, device_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -246,10 +213,7 @@ int mv_facial_landmark_configure(mv_facial_landmark_h handle)
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("facial_landmark"));
-
-               task->configure();
+               machine_learning_native_configure(handle, TASK_NAME);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -268,10 +232,7 @@ int mv_facial_landmark_prepare(mv_facial_landmark_h handle)
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("facial_landmark"));
-
-               task->prepare();
+               machine_learning_native_prepare(handle, TASK_NAME);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -291,13 +252,9 @@ int mv_facial_landmark_inference(mv_facial_landmark_h handle, mv_source_h source
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("facial_landmark"));
-
                LandmarkDetectionInput input(source);
 
-               task->setInput(input);
-               task->perform();
+               machine_learning_native_inference(handle, TASK_NAME, input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -317,12 +274,9 @@ int mv_facial_landmark_inference_async(mv_facial_landmark_h handle, mv_source_h
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("facial_landmark"));
-
                LandmarkDetectionInput input(source);
 
-               task->performAsync(input);
+               machine_learning_native_inference_async(handle, TASK_NAME, input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -345,10 +299,7 @@ int mv_facial_landmark_get_positions(mv_facial_landmark_h handle, unsigned int *
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("facial_landmark"));
-
-               auto &result = static_cast<LandmarkDetectionResult &>(task->getOutput());
+               auto &result = static_cast<LandmarkDetectionResult &>(machine_learning_native_get_result(handle, TASK_NAME));
                *number_of_landmarks = result.number_of_landmarks;
                *pos_x = result.x_pos.data();
                *pos_y = result.y_pos.data();
diff --git a/mv_machine_learning/landmark_detection/src/mv_pose_landmark.cpp b/mv_machine_learning/landmark_detection/src/mv_pose_landmark.cpp
index 72f3aa4..f440e06 100644
--- a/mv_machine_learning/landmark_detection/src/mv_pose_landmark.cpp
+++ b/mv_machine_learning/landmark_detection/src/mv_pose_landmark.cpp
@@ -19,6 +19,7 @@
 #include "mv_pose_landmark_internal.h"
 #include "pose_landmark_adapter.h"
 #include "machine_learning_exception.h"
+#include "MachineLearningNative.h"
 #include "landmark_detection_type.h"
 #include "context.h"
 
@@ -26,8 +27,8 @@
 #include <unistd.h>
 #include <string>
 #include <algorithm>
-#include <mutex>
-#include <iostream>
+
+#define TASK_NAME "pose_landmark"
 
 using namespace std;
 using namespace mediavision::inference;
@@ -35,7 +36,7 @@ using namespace mediavision::common;
 using namespace mediavision::machine_learning;
 using namespace MediaVision::Common;
 using namespace mediavision::machine_learning::exception;
-using LandmarkDetectionTask = ITask<InputBaseType, OutputBaseType>;
+using LandmarkDetectionTask = PoseLandmarkAdapter<InputBaseType, OutputBaseType>;
 
 int mv_pose_landmark_create(mv_pose_landmark_h *handle)
 {
@@ -44,17 +45,10 @@ int mv_pose_landmark_create(mv_pose_landmark_h *handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       Context *context = nullptr;
-       LandmarkDetectionTask *task = nullptr;
-
        try {
-               context = new Context();
-               task = new PoseLandmarkAdapter<InputBaseType, OutputBaseType>();
-               context->__tasks.insert(make_pair("pose_landmark", task));
-               *handle = static_cast<mv_pose_landmark_h>(context);
+               machine_learning_native_create<InputBaseType, OutputBaseType>(TASK_NAME, new LandmarkDetectionTask(), handle);
        } catch (const BaseException &e) {
-               delete task;
-               delete context;
+               LOGE("%s", e.what());
                return e.getError();
        }
 
@@ -70,12 +64,7 @@ int mv_pose_landmark_destroy(mv_pose_landmark_h handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       auto context = static_cast<Context *>(handle);
-
-       for (auto &m : context->__tasks)
-               delete static_cast<LandmarkDetectionTask *>(m.second);
-
-       delete context;
+       machine_learning_native_destory(handle, TASK_NAME);
 
        MEDIA_VISION_FUNCTION_LEAVE();
 
@@ -96,10 +85,7 @@ int mv_pose_landmark_set_model(mv_pose_landmark_h handle, const char *model_name
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("pose_landmark"));
-
-               task->setModelInfo(model_file, meta_file, label_file, model_name);
+               machine_learning_native_set_model(handle, TASK_NAME, model_file, meta_file, label_file);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -121,10 +107,7 @@ int mv_pose_landmark_set_engine(mv_pose_landmark_h handle, const char *backend_t
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("pose_landmark"));
-
-               task->setEngineInfo(backend_type, device_type);
+               machine_learning_native_set_engine(handle, TASK_NAME, backend_type, device_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -145,11 +128,7 @@ int mv_pose_landmark_get_engine_count(mv_pose_landmark_h handle, unsigned int *e
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("pose_landmark"));
-
-               task->getNumberOfEngines(engine_count);
-               // TODO.
+               machine_learning_native_get_engine_count(handle, TASK_NAME, engine_count);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -170,11 +149,7 @@ int mv_pose_landmark_get_engine_type(mv_pose_landmark_h handle, const unsigned i
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("pose_landmark"));
-
-               task->getEngineType(engine_index, engine_type);
-               // TODO.
+               machine_learning_native_get_engine_type(handle, TASK_NAME, engine_index, engine_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -195,11 +170,7 @@ int mv_pose_landmark_get_device_count(mv_pose_landmark_h handle, const char *eng
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("pose_landmark"));
-
-               task->getNumberOfDevices(engine_type, device_count);
-               // TODO.
+               machine_learning_native_get_device_count(handle, TASK_NAME, engine_type, device_count);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -222,11 +193,7 @@ int mv_pose_landmark_get_device_type(mv_pose_landmark_h handle, const char *engi
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("pose_landmark"));
-
-               task->getDeviceType(engine_type, device_index, device_type);
-               // TODO.
+               machine_learning_native_get_device_type(handle, TASK_NAME, engine_type, device_index, device_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -245,10 +212,7 @@ int mv_pose_landmark_configure(mv_pose_landmark_h handle)
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("pose_landmark"));
-
-               task->configure();
+               machine_learning_native_configure(handle, TASK_NAME);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -267,10 +231,7 @@ int mv_pose_landmark_prepare(mv_pose_landmark_h handle)
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("pose_landmark"));
-
-               task->prepare();
+               machine_learning_native_prepare(handle, TASK_NAME);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -290,13 +251,9 @@ int mv_pose_landmark_inference(mv_pose_landmark_h handle, mv_source_h source)
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("pose_landmark"));
-
                LandmarkDetectionInput input(source);
 
-               task->setInput(input);
-               task->perform();
+               machine_learning_native_inference(handle, TASK_NAME, input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -316,12 +273,9 @@ int mv_pose_landmark_inference_async(mv_pose_landmark_h handle, mv_source_h sour
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("pose_landmark"));
-
                LandmarkDetectionInput input(source);
 
-               task->performAsync(input);
+               machine_learning_native_inference_async(handle, TASK_NAME, input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -344,10 +298,7 @@ int mv_pose_landmark_get_pos(mv_pose_landmark_h handle, unsigned int *number_of_
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<LandmarkDetectionTask *>(context->__tasks.at("pose_landmark"));
-
-               auto &result = static_cast<LandmarkDetectionResult &>(task->getOutput());
+               auto &result = static_cast<LandmarkDetectionResult &>(machine_learning_native_get_result(handle, TASK_NAME));
                *number_of_landmarks = result.number_of_landmarks;
                *pos_x = result.x_pos.data();
                *pos_y = result.y_pos.data();