test: add test case for gaze tracking task group
authorInki Dae <inki.dae@samsung.com>
Mon, 15 Jul 2024 07:31:47 +0000 (16:31 +0900)
committerInki Dae <inki.dae@samsung.com>
Tue, 24 Sep 2024 02:07:26 +0000 (11:07 +0900)
Change-Id: I7df3dd4160b2e7a33cc98f4d0caf60ba00c658fd
Signed-off-by: Inki Dae <inki.dae@samsung.com>
test/CMakeLists.txt
test/testsuites/machine_learning/CMakeLists.txt
test/testsuites/machine_learning/gaze_tracking/CMakeLists.txt [new file with mode: 0644]
test/testsuites/machine_learning/gaze_tracking/test_gaze_tracking.cpp [new file with mode: 0644]
test/testsuites/machine_learning/gaze_tracking/test_gaze_tracking_async.cpp [new file with mode: 0644]

index 8707e5dec6947340728adc9889d20af3490ef264..7870e4e5b93d7162b6795a9aa24c53083ac9c5fa 100644 (file)
@@ -52,6 +52,8 @@ if (${ENABLE_ML_IMAGE_SEGMENTATION})
   include(testsuites/machine_learning/image_segmentation/CMakeLists.txt)
 endif()
 
+include(testsuites/machine_learning/gaze_tracking/CMakeLists.txt)
+
 if (${BUILD_DEPTH_STREAM_TESTSUITE})
   set(SRC_FILES ${SRC_FILES} testsuites/mv3d/test_3d.cpp)
 endif()
@@ -95,6 +97,8 @@ if (${ENABLE_ML_IMAGE_SEGMENTATION})
 target_link_libraries(${PROJECT_NAME} ${MV_IMAGE_SEGMENTATION_LIB_NAME})
 endif()
 
+target_link_libraries(${PROJECT_NAME} ${MV_GAZE_TRACKING_LIB_NAME})
+
 if (${BUILD_DEPTH_STREAM_TESTSUITE})
 target_link_libraries(${PROJECT_NAME} mv_3d)
 endif()
index eb762325eccf73d09b2204bf97155d92ac906fae..ac065e9129f4010bd75f139d318d24c0816346d4 100644 (file)
@@ -27,3 +27,5 @@ if (${ENABLE_ML_IMAGE_SEGMENTATION})
     message("Enabled machine learning image segmentation test case.")
     add_subdirectory(${PROJECT_SOURCE_DIR}/image_segmentation)
 endif()
+message("Enabled machine learning gaze tracking test case.")
+add_subdirectory(${PROJECT_SOURCE_DIR}/gaze_tracking)
diff --git a/test/testsuites/machine_learning/gaze_tracking/CMakeLists.txt b/test/testsuites/machine_learning/gaze_tracking/CMakeLists.txt
new file mode 100644 (file)
index 0000000..6b4dbdc
--- /dev/null
@@ -0,0 +1,5 @@
+set(SRC_FILES
+    ${SRC_FILES}
+    testsuites/machine_learning/gaze_tracking/test_gaze_tracking.cpp
+    testsuites/machine_learning/gaze_tracking/test_gaze_tracking_async.cpp
+)
\ No newline at end of file
diff --git a/test/testsuites/machine_learning/gaze_tracking/test_gaze_tracking.cpp b/test/testsuites/machine_learning/gaze_tracking/test_gaze_tracking.cpp
new file mode 100644 (file)
index 0000000..49e8ef5
--- /dev/null
@@ -0,0 +1,135 @@
+/**
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <iostream>
+#include <string.h>
+
+#include "gtest/gtest.h"
+
+#include "../task_model_info.hpp"
+#include "ImageHelper.h"
+#include "mv_gaze_tracking.h"
+#include "mv_gaze_tracking_internal.h"
+
+#define IMG_GAZE TEST_RES_PATH "/res/inference/images/gazeDetection.jpg"
+
+using namespace testing;
+using namespace std;
+
+using namespace MediaVision::Common;
+
+TEST(GazeTrackingTest, GettingAvailableInferenceEnginesInfoShouldBeOk)
+{
+       mv_gaze_tracking_h handle;
+
+       int ret = mv_gaze_tracking_create(&handle);
+       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+       unsigned int engine_count = 0;
+
+       ret = mv_gaze_tracking_get_engine_count(handle, &engine_count);
+       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+       ASSERT_GE(engine_count, 1);
+
+       for (unsigned int engine_idx = 0; engine_idx < engine_count; ++engine_idx) {
+               char *engine_type = nullptr;
+
+               ret = mv_gaze_tracking_get_engine_type(handle, engine_idx, &engine_type);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               unsigned int device_count = 0;
+
+               ret = mv_gaze_tracking_get_device_count(handle, engine_type, &device_count);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+               ASSERT_GE(engine_count, 1);
+
+               for (unsigned int device_idx = 0; device_idx < device_count; ++device_idx) {
+                       char *device_type = nullptr;
+
+                       ret = mv_gaze_tracking_get_device_type(handle, engine_type, device_idx, &device_type);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+               }
+       }
+
+       ret = mv_gaze_tracking_destroy(handle);
+       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+}
+
+TEST(GazeTrackingTest, InferenceShouldBeOk)
+{
+       mv_gaze_tracking_h handle;
+       vector<test_model_input> test_models {
+               {} // If empty then default model will be used.
+       };
+
+       mv_source_h mv_source = NULL;
+       int ret = mv_create_source(&mv_source);
+       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+       ret = ImageHelper::loadImageToSource(IMG_GAZE, mv_source);
+       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+       unsigned int answer_idx = 0;
+
+       for (const auto &model : test_models) {
+               ret = mv_gaze_tracking_create(&handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_set_model(handle, model.model_file.c_str(), model.meta_file.c_str(),
+                                                                                model.label_file.c_str(), NULL);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_set_engine(handle, "tflite", "cpu");
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_configure(handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_prepare(handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_inference(handle, mv_source);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               unsigned long frame_number;
+               unsigned int number_of_faces;
+
+               ret = mv_gaze_tracking_get_result_count(handle, &frame_number, &number_of_faces);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               for (unsigned int idx = 0; idx < number_of_faces; ++idx) {
+                       float x, y, yaw, pitch;
+
+                       int ret = mv_gaze_tracking_get_pos(handle, idx, &x, &y);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+                       ret = mv_gaze_tracking_get_raw_data(handle, idx, &yaw, &pitch);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+                       cout << "frame number = " << frame_number << " x = " << x << " y = " << y << " yaw = " << yaw
+                                << "  pitch = " << pitch << endl;
+               }
+
+               answer_idx++;
+
+               ret = mv_gaze_tracking_destroy(handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+       }
+
+       ret = mv_destroy_source(mv_source);
+       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+}
diff --git a/test/testsuites/machine_learning/gaze_tracking/test_gaze_tracking_async.cpp b/test/testsuites/machine_learning/gaze_tracking/test_gaze_tracking_async.cpp
new file mode 100644 (file)
index 0000000..7acc5d8
--- /dev/null
@@ -0,0 +1,171 @@
+/**
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <iostream>
+#include <string.h>
+#include <thread>
+
+#include "gtest/gtest.h"
+
+#include "../task_model_info.hpp"
+#include "ImageHelper.h"
+#include "mv_gaze_tracking.h"
+#include "mv_gaze_tracking_internal.h"
+
+#define IMG_GAZE TEST_RES_PATH "/res/inference/images/gazeDetection.jpg"
+#define MAX_INFERENCE_ITERATION 50
+
+using namespace testing;
+using namespace std;
+
+using namespace MediaVision::Common;
+
+void gaze_tracking_callback(void *user_data)
+{
+       mv_gaze_tracking_h handle = static_cast<mv_gaze_tracking_h>(user_data);
+
+       bool is_loop_exit = false;
+
+       while (!is_loop_exit) {
+               unsigned long frame_number;
+               unsigned int number_of_results;
+
+               int ret = mv_gaze_tracking_get_result_count(handle, &frame_number, &number_of_results);
+               if (ret == MEDIA_VISION_ERROR_INVALID_OPERATION)
+                       break;
+
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               for (unsigned int idx = 0; idx < number_of_results; ++idx) {
+                       float x, y, yaw, pitch;
+
+                       int ret = mv_gaze_tracking_get_pos(handle, idx, &x, &y);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+                       ret = mv_gaze_tracking_get_raw_data(handle, idx, &yaw, &pitch);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+                       cout << "frame number = " << frame_number << " x = " << x << " y = " << y << " yaw = " << yaw
+                                << "  pitch = " << pitch << endl;
+
+                       if (frame_number > MAX_INFERENCE_ITERATION - 10)
+                               is_loop_exit = true;
+               }
+       }
+}
+
+TEST(GazeTrackingAsyncTest, InferenceShouldBeOk)
+{
+       mv_gaze_tracking_h handle;
+       vector<test_model_input> test_models {
+               {} // If empty then default model will be used.
+       };
+
+       for (auto &model : test_models) {
+               int ret = mv_gaze_tracking_create(&handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_set_model(handle, model.model_file.c_str(), model.meta_file.c_str(),
+                                                                                model.label_file.c_str(), NULL);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_set_engine(handle, "tflite", "cpu");
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_configure(handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_prepare(handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               unique_ptr<thread> thread_handle;
+
+               for (unsigned int iter = 0; iter < MAX_INFERENCE_ITERATION; ++iter) {
+                       mv_source_h mv_source = NULL;
+                       ret = mv_create_source(&mv_source);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+                       ret = ImageHelper::loadImageToSource(IMG_GAZE, mv_source);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+                       ret = mv_gaze_tracking_inference_async(handle, mv_source);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+                       if (iter == 0)
+                               thread_handle = make_unique<thread>(&gaze_tracking_callback, static_cast<void *>(handle));
+
+                       ret = mv_destroy_source(mv_source);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+               }
+
+               thread_handle->join();
+
+               ret = mv_gaze_tracking_destroy(handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+       }
+}
+
+TEST(GazeTrackingAsyncTest, InferenceShouldBeOkWithDestroyFirst)
+{
+       mv_gaze_tracking_h handle;
+       vector<test_model_input> test_models {
+               {} // If empty then default model will be used.
+       };
+
+       for (auto &model : test_models) {
+               int ret = mv_gaze_tracking_create(&handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_set_model(handle, model.model_file.c_str(), model.meta_file.c_str(),
+                                                                                model.label_file.c_str(), NULL);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_set_engine(handle, "tflite", "cpu");
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_configure(handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_prepare(handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               unique_ptr<thread> thread_handle;
+
+               for (unsigned int iter = 0; iter < MAX_INFERENCE_ITERATION; ++iter) {
+                       mv_source_h mv_source = NULL;
+                       ret = mv_create_source(&mv_source);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+                       ret = ImageHelper::loadImageToSource(IMG_GAZE, mv_source);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+                       ret = mv_gaze_tracking_inference_async(handle, mv_source);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+                       if (iter == 0)
+                               thread_handle = make_unique<thread>(&gaze_tracking_callback, static_cast<void *>(handle));
+
+                       ret = mv_destroy_source(mv_source);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+               }
+
+               ret = mv_gaze_tracking_destroy(handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               thread_handle->join();
+       }
+}