mv_machine_learning: introduce MachineLearningNative module
author Inki Dae <inki.dae@samsung.com>
Mon, 27 Nov 2023 01:52:12 +0000 (10:52 +0900)
committer Kwanghoon Son <k.son@samsung.com>
Wed, 6 Dec 2023 01:36:46 +0000 (10:36 +0900)
[Issue type] : new feature

Introduce MachineLearningNative module which can be used commonly
for all task groups for context management.

This patch allows us to eliminate the duplicated code within each task's native
API implementation module, where the module name begins with the "mv_" prefix.
By consolidating all task management code into the common module
- MachineLearningNative.cpp - and utilizing this module instead of creating
separate implementations in each task's native API module, we can streamline
development and reduce redundancy.

As a starting point, this patch makes the image classification task group
utilize the common module rather than its own internal code.
The goal of this change is to determine whether or not we can offer a unified
interface for all task groups. If successful, this will allow us to apply
the same native module to other task groups as well.

If our initial tests prove successful, we can move forward with implementing
a task manager that utilizes an inference and training pipeline graph approach
to manage individual tasks.

Change-Id: I4cf4f2a06294acae7ac94d3f562958d2ad3d0770
Signed-off-by: Inki Dae <inki.dae@samsung.com>
mv_machine_learning/common/include/MachineLearningNative.h [new file with mode: 0644]
mv_machine_learning/common/src/MachineLearningNative.cpp [new file with mode: 0644]
mv_machine_learning/image_classification/src/mv_image_classification.cpp

diff --git a/mv_machine_learning/common/include/MachineLearningNative.h b/mv_machine_learning/common/include/MachineLearningNative.h
new file mode 100644 (file)
index 0000000..a6355f9
--- /dev/null
@@ -0,0 +1,56 @@
+/**
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MACHINE_LEARNING_NATIVE_H__
+#define __MACHINE_LEARNING_NATIVE_H__
+
+#include <memory>
+#include <string>
+
+#include "context.h"
+#include "itask.h"
+#include "mv_private.h"
+
+#include "MachineLearningType.h"
+
+namespace mediavision
+{
+namespace machine_learning
+{
+template<typename T, typename V>
+void machine_learning_native_create(const std::string &task_name, mediavision::common::ITask<T, V> *task,
+                                                                       void **handle);
+void machine_learning_native_destory(void *handle, const std::string &task_name);
+void machine_learning_native_configure(void *handle, const std::string &task_name);
+void machine_learning_native_prepare(void *handle, const std::string &task_name);
+void machine_learning_native_inference(void *handle, const std::string &task_name, InputBaseType &input);
+void machine_learning_native_inference_async(void *handle, const std::string &task_name, InputBaseType &input);
+OutputBaseType &machine_learning_native_get_result(void *handle, const std::string &task_name);
+void machine_learning_native_set_model(void *handle, const std::string &task_name, const char *model_file,
+                                                                          const char *meta_file, const char *label_file);
+void machine_learning_native_set_engine(void *handle, const std::string &task_name, const char *backend_type,
+                                                                               const char *device_type);
+void machine_learning_native_get_engine_count(void *handle, const std::string &task_name, unsigned int *engine_count);
+void machine_learning_native_get_engine_type(void *handle, const std::string &task_name,
+                                                                                        const unsigned int engine_index, char **engine_type);
+void machine_learning_native_get_device_count(void *handle, const std::string &task_name, const char *engine_type,
+                                                                                         unsigned int *device_count);
+void machine_learning_native_get_device_type(void *handle, const std::string &task_name, const char *engine_type,
+                                                                                        const unsigned int device_index, char **device_type);
+} // machine_learning
+} // mediavision
+
+#endif
\ No newline at end of file
diff --git a/mv_machine_learning/common/src/MachineLearningNative.cpp b/mv_machine_learning/common/src/MachineLearningNative.cpp
new file mode 100644 (file)
index 0000000..666eae2
--- /dev/null
@@ -0,0 +1,146 @@
+/**
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mv_private.h"
+
+#include "MachineLearningNative.h"
+#include "machine_learning_exception.h"
+
+using namespace std;
+using namespace mediavision::common;
+using namespace mediavision::machine_learning;
+using namespace mediavision::machine_learning::exception;
+
+using MachineLearningTask = ITask<InputBaseType, OutputBaseType>;
+
+namespace mediavision
+{
+namespace machine_learning
+{
+inline MachineLearningTask *get_task(void *handle, const std::string &name)
+{
+       auto context = static_cast<Context *>(handle);
+
+       return static_cast<MachineLearningTask *>(context->__tasks.at(name));
+}
+
+template<typename T, typename V>
+void machine_learning_native_create(const string &task_name, ITask<T, V> *task, void **handle)
+{
+       Context *context = new Context();
+
+       context->__tasks.insert(make_pair(task_name, task));
+       *handle = static_cast<void *>(context);
+}
+
+template void machine_learning_native_create<InputBaseType, OutputBaseType>(const string &task_name,
+                                                                                                                                                       MachineLearningTask *task, void **handle);
+
+void machine_learning_native_destory(void *handle, const string &task_name)
+{
+       auto context = static_cast<Context *>(handle);
+
+       for (auto &m : context->__tasks)
+               delete static_cast<MachineLearningTask *>(m.second);
+
+       delete context;
+}
+
+void machine_learning_native_configure(void *handle, const string &task_name)
+{
+       auto task = get_task(handle, task_name);
+
+       task->configure();
+}
+
+void machine_learning_native_prepare(void *handle, const string &task_name)
+{
+       auto task = get_task(handle, task_name);
+
+       task->prepare();
+}
+
+void machine_learning_native_inference(void *handle, const string &task_name, InputBaseType &input)
+{
+       auto task = get_task(handle, task_name);
+
+       task->setInput(input);
+       task->perform();
+}
+
+void machine_learning_native_inference_async(void *handle, const string &task_name, InputBaseType &input)
+{
+       auto task = get_task(handle, task_name);
+
+       task->performAsync(input);
+}
+
+OutputBaseType &machine_learning_native_get_result(void *handle, const string &task_name)
+{
+       auto task = get_task(handle, task_name);
+
+       return task->getOutput();
+}
+
+void machine_learning_native_set_model(void *handle, const string &task_name, const char *model_file,
+                                                                          const char *meta_file, const char *label_file)
+{
+       auto task = get_task(handle, task_name);
+
+       task->setModelInfo(model_file, meta_file, label_file);
+}
+
+void machine_learning_native_set_engine(void *handle, const string &task_name, const char *backend_type,
+                                                                               const char *device_type)
+{
+       auto task = get_task(handle, task_name);
+
+       task->setEngineInfo(backend_type, device_type);
+}
+
+void machine_learning_native_get_engine_count(void *handle, const string &task_name, unsigned int *engine_count)
+{
+       auto task = get_task(handle, task_name);
+
+       task->getNumberOfEngines(engine_count);
+}
+
+void machine_learning_native_get_engine_type(void *handle, const string &task_name, const unsigned int engine_index,
+                                                                                        char **engine_type)
+{
+       auto task = get_task(handle, task_name);
+
+       task->getEngineType(engine_index, engine_type);
+}
+
+void machine_learning_native_get_device_count(void *handle, const string &task_name, const char *engine_type,
+                                                                                         unsigned int *device_count)
+{
+       auto task = get_task(handle, task_name);
+
+       task->getNumberOfDevices(engine_type, device_count);
+}
+
+void machine_learning_native_get_device_type(void *handle, const string &task_name, const char *engine_type,
+                                                                                        const unsigned int device_index, char **device_type)
+{
+       auto task = get_task(handle, task_name);
+
+       task->getDeviceType(engine_type, device_index, device_type);
+}
+
+}
+}
index 9908a0b..baa4ada 100644 (file)
@@ -19,6 +19,7 @@
 #include "mv_image_classification_internal.h"
 #include "image_classification_adapter.h"
 #include "machine_learning_exception.h"
+#include "MachineLearningNative.h"
 #include "image_classification_type.h"
 #include "context.h"
 
 #include <string>
 #include <algorithm>
 
+#define TASK_NAME "image_classification"
+
 using namespace std;
 using namespace mediavision::inference;
 using namespace mediavision::common;
 using namespace mediavision::machine_learning;
 using namespace MediaVision::Common;
 using namespace mediavision::machine_learning::exception;
-using ImageClassificationTask = ITask<InputBaseType, OutputBaseType>;
+using ImageClassificationTask = ImageClassificationAdapter<InputBaseType, OutputBaseType>;
 
 int mv_image_classification_create(mv_image_classification_h *out_handle)
 {
-       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_check_system_info_feature_supported());
+       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
        MEDIA_VISION_NULL_ARG_CHECK(out_handle);
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       Context *context = nullptr;
-       ImageClassificationTask *task = nullptr;
-
        try {
-               context = new Context();
-               task = new ImageClassificationAdapter<InputBaseType, OutputBaseType>();
-
-               context->__tasks.insert(make_pair("image_classification", task));
-               *out_handle = static_cast<mv_image_classification_h>(context);
+               machine_learning_native_create<InputBaseType, OutputBaseType>(TASK_NAME, new ImageClassificationTask(),
+                                                                                                                                         out_handle);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
-               delete task;
-               delete context;
                return e.getError();
        }
 
@@ -67,17 +62,12 @@ int mv_image_classification_create(mv_image_classification_h *out_handle)
 
 int mv_image_classification_destroy(mv_image_classification_h handle)
 {
-       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_check_system_info_feature_supported());
-       MEDIA_VISION_INSTANCE_CHECK(handle);
+       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
+       MEDIA_VISION_NULL_ARG_CHECK(handle);
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       auto context = static_cast<Context *>(handle);
-
-       for (auto &m : context->__tasks)
-               delete static_cast<ImageClassificationTask *>(m.second);
-
-       delete context;
+       machine_learning_native_destory(handle, TASK_NAME);
 
        MEDIA_VISION_FUNCTION_LEAVE();
 
@@ -86,19 +76,13 @@ int mv_image_classification_destroy(mv_image_classification_h handle)
 
 int mv_image_classification_configure(mv_image_classification_h handle)
 {
-       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_check_system_info_feature_supported());
+       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
        MEDIA_VISION_INSTANCE_CHECK(handle);
 
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
-               if (!task) {
-                       return MEDIA_VISION_ERROR_INVALID_OPERATION;
-               }
-
-               task->configure();
+               machine_learning_native_configure(handle, TASK_NAME);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -111,16 +95,13 @@ int mv_image_classification_configure(mv_image_classification_h handle)
 
 int mv_image_classification_prepare(mv_image_classification_h handle)
 {
-       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_check_system_info_feature_supported());
+       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
        MEDIA_VISION_INSTANCE_CHECK(handle);
 
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
-
-               task->prepare();
+               machine_learning_native_prepare(handle, TASK_NAME);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -140,13 +121,9 @@ int mv_image_classification_inference(mv_image_classification_h handle, mv_sourc
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
-
                ImageClassificationInput input(source);
 
-               task->setInput(input);
-               task->perform();
+               machine_learning_native_inference(handle, TASK_NAME, input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -166,12 +143,9 @@ int mv_image_classification_inference_async(mv_image_classification_h handle, mv
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
-
                ImageClassificationInput input(source);
 
-               task->performAsync(input);
+               machine_learning_native_inference_async(handle, TASK_NAME, input);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -184,18 +158,15 @@ int mv_image_classification_inference_async(mv_image_classification_h handle, mv
 
 int mv_image_classification_get_label(mv_image_classification_h handle, const char **out_label)
 {
-       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_face_check_system_info_feature_supported());
-
+       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
        MEDIA_VISION_INSTANCE_CHECK(handle);
        MEDIA_VISION_NULL_ARG_CHECK(out_label);
 
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
-
-               const auto &result = static_cast<ImageClassificationResult &>(task->getOutput());
+               const auto &result =
+                               static_cast<ImageClassificationResult &>(machine_learning_native_get_result(handle, TASK_NAME));
 
                *out_label = result.label.c_str();
        } catch (const BaseException &e) {
@@ -211,8 +182,7 @@ int mv_image_classification_get_label(mv_image_classification_h handle, const ch
 int mv_image_classification_set_model(mv_image_classification_h handle, const char *model_file, const char *meta_file,
                                                                          const char *label_file)
 {
-       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_face_check_system_info_feature_supported());
-
+       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
        MEDIA_VISION_INSTANCE_CHECK(handle);
        MEDIA_VISION_NULL_ARG_CHECK(model_file);
        MEDIA_VISION_NULL_ARG_CHECK(meta_file);
@@ -221,10 +191,7 @@ int mv_image_classification_set_model(mv_image_classification_h handle, const ch
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
-
-               task->setModelInfo(model_file, meta_file, label_file);
+               machine_learning_native_set_model(handle, TASK_NAME, model_file, meta_file, label_file);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -238,8 +205,7 @@ int mv_image_classification_set_model(mv_image_classification_h handle, const ch
 int mv_image_classification_set_engine(mv_image_classification_h handle, const char *backend_type,
                                                                           const char *device_type)
 {
-       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_face_check_system_info_feature_supported());
-
+       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
        MEDIA_VISION_INSTANCE_CHECK(handle);
        MEDIA_VISION_NULL_ARG_CHECK(backend_type);
        MEDIA_VISION_NULL_ARG_CHECK(device_type);
@@ -247,10 +213,7 @@ int mv_image_classification_set_engine(mv_image_classification_h handle, const c
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
-
-               task->setEngineInfo(backend_type, device_type);
+               machine_learning_native_set_engine(handle, TASK_NAME, backend_type, device_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -263,19 +226,14 @@ int mv_image_classification_set_engine(mv_image_classification_h handle, const c
 
 int mv_image_classification_get_engine_count(mv_image_classification_h handle, unsigned int *engine_count)
 {
-       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_face_check_system_info_feature_supported());
-
+       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
        MEDIA_VISION_INSTANCE_CHECK(handle);
        MEDIA_VISION_NULL_ARG_CHECK(engine_count);
 
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
-
-               task->getNumberOfEngines(engine_count);
-               // TODO.
+               machine_learning_native_get_engine_count(handle, TASK_NAME, engine_count);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -289,19 +247,14 @@ int mv_image_classification_get_engine_count(mv_image_classification_h handle, u
 int mv_image_classification_get_engine_type(mv_image_classification_h handle, const unsigned int engine_index,
                                                                                        char **engine_type)
 {
-       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_face_check_system_info_feature_supported());
-
+       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
        MEDIA_VISION_INSTANCE_CHECK(handle);
        MEDIA_VISION_NULL_ARG_CHECK(engine_type);
 
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
-
-               task->getEngineType(engine_index, engine_type);
-               // TODO.
+               machine_learning_native_get_engine_type(handle, TASK_NAME, engine_index, engine_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -315,19 +268,14 @@ int mv_image_classification_get_engine_type(mv_image_classification_h handle, co
 int mv_image_classification_get_device_count(mv_image_classification_h handle, const char *engine_type,
                                                                                         unsigned int *device_count)
 {
-       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_face_check_system_info_feature_supported());
-
+       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
        MEDIA_VISION_INSTANCE_CHECK(handle);
        MEDIA_VISION_NULL_ARG_CHECK(device_count);
 
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
-
-               task->getNumberOfDevices(engine_type, device_count);
-               // TODO.
+               machine_learning_native_get_device_count(handle, TASK_NAME, engine_type, device_count);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -341,8 +289,7 @@ int mv_image_classification_get_device_count(mv_image_classification_h handle, c
 int mv_image_classification_get_device_type(mv_image_classification_h handle, const char *engine_type,
                                                                                        const unsigned int device_index, char **device_type)
 {
-       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_face_check_system_info_feature_supported());
-
+       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
        MEDIA_VISION_INSTANCE_CHECK(handle);
        MEDIA_VISION_NULL_ARG_CHECK(engine_type);
        MEDIA_VISION_NULL_ARG_CHECK(device_type);
@@ -350,11 +297,7 @@ int mv_image_classification_get_device_type(mv_image_classification_h handle, co
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               auto context = static_cast<Context *>(handle);
-               auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
-
-               task->getDeviceType(engine_type, device_index, device_type);
-               // TODO.
+               machine_learning_native_get_device_type(handle, TASK_NAME, engine_type, device_index, device_type);
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();