mv_machine_learning: use MachineLearningNative module for other task groups
authorInki Dae <inki.dae@samsung.com>
Wed, 13 Dec 2023 07:52:23 +0000 (16:52 +0900)
committerKwanghoon Son <k.son@samsung.com>
Wed, 27 Dec 2023 03:22:49 +0000 (12:22 +0900)
[Issue type] : code cleanup

Use MachineLearningNative module for image segmentation and object detection
3d task groups to drop code duplication.

Change-Id: Ief662f216143a1463fd01a6a909e96799c07aff1
Signed-off-by: Inki Dae <inki.dae@samsung.com>
mv_machine_learning/image_segmentation/src/mv_selfie_segmentation.cpp
mv_machine_learning/object_detection_3d/src/mv_object_detection_3d.cpp

index ecf71b5..5c3c25b 100644 (file)
@@ -21,6 +21,7 @@
 #include "mv_selfie_segmentation_internal.h"
 #include "selfie_segmentation_adapter.h"
 #include "machine_learning_exception.h"
+#include "MachineLearningNative.h"
 #include "image_segmentation_type.h"
 #include "context.h"
 
@@ -31,6 +32,8 @@
 #include <mutex>
 #include <iostream>
 
+#define TASK_NAME "selfie_segmentation"
+
 using namespace std;
 using namespace mediavision::inference;
 using namespace mediavision::common;
@@ -49,20 +52,17 @@ int mv_selfie_segmentation_create(mv_selfie_segmentation_h *handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       Context *context = nullptr;
-       ITask *task = nullptr;
+       mv_selfie_segmentation_h ctx = nullptr;
 
        try {
-               context = new Context();
-               task = new ImageSegmentationAdapter();
-               context->__tasks.insert(make_pair("selfie_segmentation", task));
-               *handle = static_cast<mv_selfie_segmentation_h>(context);
+               ctx = machine_learning_native_create();
+               machine_learning_native_add(ctx, TASK_NAME, new ImageSegmentationAdapter());
        } catch (const BaseException &e) {
-               delete task;
-               delete context;
+               LOGE("%s", e.what());
                return e.getError();
        }
 
+       *handle = ctx;
        MEDIA_VISION_FUNCTION_LEAVE();
 
        return MEDIA_VISION_ERROR_NONE;
@@ -75,12 +75,7 @@ int mv_selfie_segmentation_destroy(mv_selfie_segmentation_h handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       auto context = static_cast<Context *>(handle);
-
-       for (auto &m : context->__tasks)
-               delete m.second;
-
-       delete context;
+       machine_learning_native_destory(handle);
 
        MEDIA_VISION_FUNCTION_LEAVE();
 
@@ -97,10 +92,7 @@ int mv_selfie_segmentation_set_model(mv_selfie_segmentation_h handle, const char
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("selfie_segmentation");
-
-               task->setModelInfo(model_file, meta_file, label_file, model_name);
+               machine_learning_native_set_model(handle, TASK_NAME, model_file, meta_file, label_file);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -123,10 +115,7 @@ int mv_selfie_segmentation_set_engine(mv_selfie_segmentation_h handle, const cha
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("selfie_segmentation");
-
-               task->setEngineInfo(backend_type, device_type);
+               machine_learning_native_set_engine(handle, TASK_NAME, backend_type, device_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -147,11 +136,7 @@ int mv_selfie_segmentation_get_engine_count(mv_selfie_segmentation_h handle, uns
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("selfie_segmentation");
-
-               *engine_count = task->getNumberOfEngines();
-               // TODO.
+               machine_learning_native_get_engine_count(handle, TASK_NAME, engine_count);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -173,11 +158,7 @@ int mv_selfie_segmentation_get_engine_type(mv_selfie_segmentation_h handle, cons
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("selfie_segmentation");
-
-               *engine_type = (char *) task->getEngineType(engine_index).c_str();
-               // TODO.
+               machine_learning_native_get_engine_type(handle, TASK_NAME, engine_index, engine_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -199,11 +180,7 @@ int mv_selfie_segmentation_get_device_count(mv_selfie_segmentation_h handle, con
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("selfie_segmentation");
-
-               *device_count = task->getNumberOfDevices(engine_type);
-               // TODO.
+               machine_learning_native_get_device_count(handle, TASK_NAME, engine_type, device_count);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -226,11 +203,7 @@ int mv_selfie_segmentation_get_device_type(mv_selfie_segmentation_h handle, cons
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("selfie_segmentation");
-
-               *device_type = (char *) task->getDeviceType(engine_type, device_index).c_str();
-               // TODO.
+               machine_learning_native_get_device_type(handle, TASK_NAME, engine_type, device_index, device_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -249,10 +222,7 @@ int mv_selfie_segmentation_configure(mv_selfie_segmentation_h handle)
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("selfie_segmentation");
-
-               task->configure();
+               machine_learning_native_configure(handle, TASK_NAME);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -271,10 +241,7 @@ int mv_selfie_segmentation_prepare(mv_selfie_segmentation_h handle)
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("selfie_segmentation");
-
-               task->prepare();
+               machine_learning_native_prepare(handle, TASK_NAME);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -294,12 +261,9 @@ int mv_selfie_segmentation_inference(mv_selfie_segmentation_h handle, mv_source_
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("selfie_segmentation");
-
-               ImageSegmentationInput input = { .inference_src = source };
+               ImageSegmentationInput input(source);
 
-               task->perform(input);
+               machine_learning_native_inference(handle, TASK_NAME, input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -319,12 +283,9 @@ int mv_selfie_segmentation_inference_async(mv_selfie_segmentation_h handle, mv_s
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("selfie_segmentation");
-
-               ImageSegmentationInput input = { source };
+               ImageSegmentationInput input(source);
 
-               task->performAsync(input);
+               machine_learning_native_inference_async(handle, TASK_NAME, input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -348,10 +309,8 @@ int mv_selfie_segmentation_get_result(mv_selfie_segmentation_h handle, unsigned
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("selfie_segmentation");
+               auto &result = static_cast<ImageSegmentationResult &>(machine_learning_native_get_result(handle, TASK_NAME));
 
-               auto &result = static_cast<ImageSegmentationResult &>(task->getOutput());
                *width = result.width;
                *height = result.height;
                *pixel_size = result.pixel_size;
index fe8ff55..ff238ce 100644 (file)
@@ -20,6 +20,7 @@
 #include "mv_object_detection_3d_internal.h"
 #include "object_detection_3d_adapter.h"
 #include "machine_learning_exception.h"
+#include "MachineLearningNative.h"
 #include "object_detection_3d_type.h"
 #include "context.h"
 
@@ -29,6 +30,8 @@
 #include <algorithm>
 #include <mutex>
 
+#define TASK_NAME "object_detection_3d"
+
 using namespace std;
 using namespace mediavision::inference;
 using namespace mediavision::common;
@@ -49,20 +52,17 @@ int mv_object_detection_3d_create(mv_object_detection_3d_h *handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       Context *context = nullptr;
-       ITask *task = nullptr;
+       mv_object_detection_3d_h ctx = nullptr;
 
        try {
-               context = new Context();
-               task = new ObjectDetection3dAdapter();
-               context->__tasks.insert(make_pair("object_detection_3d", task));
-               *handle = static_cast<mv_object_detection_3d_h>(context);
+               ctx = machine_learning_native_create();
+               machine_learning_native_add(ctx, TASK_NAME, new ObjectDetection3dAdapter());
        } catch (const BaseException &e) {
-               delete task;
-               delete context;
+               LOGE("%s", e.what());
                return e.getError();
        }
 
+       *handle = ctx;
        MEDIA_VISION_FUNCTION_LEAVE();
 
        return MEDIA_VISION_ERROR_NONE;
@@ -78,12 +78,7 @@ int mv_object_detection_3d_destroy(mv_object_detection_3d_h handle)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       auto context = static_cast<Context *>(handle);
-
-       for (auto &m : context->__tasks)
-               delete m.second;
-
-       delete context;
+       machine_learning_native_destory(handle);
 
        MEDIA_VISION_FUNCTION_LEAVE();
 
@@ -106,10 +101,7 @@ int mv_object_detection_3d_set_model(mv_object_detection_3d_h handle, const char
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("object_detection_3d");
-
-               task->setModelInfo(model_file, meta_file, label_file, model_name);
+               machine_learning_native_set_model(handle, TASK_NAME, model_file, meta_file, label_file, model_name);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -134,10 +126,7 @@ int mv_object_detection_3d_set_engine(mv_object_detection_3d_h handle, const cha
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("object_detection_3d");
-
-               task->setEngineInfo(backend_type, device_type);
+               machine_learning_native_set_engine(handle, TASK_NAME, backend_type, device_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -160,11 +149,7 @@ int mv_object_detection_3d_get_engine_count(mv_object_detection_3d_h handle, uns
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("object_detection_3d");
-
-               *engine_count = task->getNumberOfEngines();
-               // TODO.
+               machine_learning_native_get_engine_count(handle, TASK_NAME, engine_count);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -188,11 +173,7 @@ int mv_object_detection_3d_get_engine_type(mv_object_detection_3d_h handle, cons
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("object_detection_3d");
-
-               *engine_type = (char *) task->getEngineType(engine_index).c_str();
-               // TODO.
+               machine_learning_native_get_engine_type(handle, TASK_NAME, engine_index, engine_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -216,11 +197,7 @@ int mv_object_detection_3d_get_device_count(mv_object_detection_3d_h handle, con
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("object_detection_3d");
-
-               *device_count = task->getNumberOfDevices(engine_type);
-               // TODO.
+               machine_learning_native_get_device_count(handle, TASK_NAME, engine_type, device_count);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -246,11 +223,7 @@ int mv_object_detection_3d_get_device_type(mv_object_detection_3d_h handle, cons
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("object_detection_3d");
-
-               *device_type = (char *) task->getDeviceType(engine_type, device_index).c_str();
-               // TODO.
+               machine_learning_native_get_device_type(handle, TASK_NAME, engine_type, device_index, device_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -272,10 +245,7 @@ int mv_object_detection_3d_configure(mv_object_detection_3d_h handle)
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("object_detection_3d");
-
-               task->configure();
+               machine_learning_native_configure(handle, TASK_NAME);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -297,10 +267,7 @@ int mv_object_detection_3d_prepare(mv_object_detection_3d_h handle)
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("object_detection_3d");
-
-               task->prepare();
+               machine_learning_native_prepare(handle, TASK_NAME);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -323,12 +290,9 @@ int mv_object_detection_3d_inference(mv_object_detection_3d_h handle, mv_source_
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("object_detection_3d");
+               ObjectDetection3dInput input(source);
 
-               ObjectDetection3dInput input = { source };
-
-               task->perform(input);
+               machine_learning_native_inference(handle, TASK_NAME, input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -351,10 +315,7 @@ int mv_object_detection_3d_get_probability(mv_object_detection_3d_h handle, unsi
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("object_detection_3d");
-
-               auto &result = static_cast<ObjectDetection3dResult &>(task->getOutput());
+               auto &result = static_cast<ObjectDetection3dResult &>(machine_learning_native_get_result(handle, TASK_NAME));
 
                *out_probability = result.probability;
        } catch (const BaseException &e) {
@@ -379,10 +340,7 @@ int mv_object_detection_3d_get_num_of_points(mv_object_detection_3d_h handle, un
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("object_detection_3d");
-
-               auto &result = static_cast<ObjectDetection3dResult &>(task->getOutput());
+               auto &result = static_cast<ObjectDetection3dResult &>(machine_learning_native_get_result(handle, TASK_NAME));
 
                *out_num_of_points = result.number_of_points;
        } catch (const BaseException &e) {
@@ -408,10 +366,7 @@ int mv_object_detection_3d_get_points(mv_object_detection_3d_h handle, unsigned
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               Context *context = static_cast<Context *>(handle);
-               auto task = context->__tasks.at("object_detection_3d");
-
-               auto &result = static_cast<ObjectDetection3dResult &>(task->getOutput());
+               auto &result = static_cast<ObjectDetection3dResult &>(machine_learning_native_get_result(handle, TASK_NAME));
 
                *out_x = result.x_vec.data();
                *out_y = result.y_vec.data();