mv_machine_learning: add L2CSNet based on mobilenetv2 67/315367/4
authorTae-Young Chung <ty83.chung@samsung.com>
Tue, 26 Nov 2024 08:51:03 +0000 (17:51 +0900)
committerTae-Young Chung <ty83.chung@samsung.com>
Mon, 2 Dec 2024 08:00:45 +0000 (17:00 +0900)
This patch was tested with a test image and the model:
the image: https://github.sec.samsung.net/k-son/mv_test_res/pull/15
the model: https://github.sec.samsung.net/tizen-vault/open_model_zoo/pull/12

Change-Id: I3ec58923e387a4d283ee508823d49fe3dd785e9e
Signed-off-by: Tae-Young Chung <ty83.chung@samsung.com>
mv_machine_learning/gaze_tracking/include/L2CSNetMobileNetV2.h [new file with mode: 0644]
mv_machine_learning/gaze_tracking/include/gaze_tracking_type.h
mv_machine_learning/gaze_tracking/src/GazeTrackingAdapter.cpp
mv_machine_learning/gaze_tracking/src/L2CSNetMobileNetV2.cpp [new file with mode: 0644]
mv_machine_learning/gaze_tracking/src/mv_gaze_tracking.cpp
test/testsuites/machine_learning/gaze_tracking/test_gaze_tracking.cpp

diff --git a/mv_machine_learning/gaze_tracking/include/L2CSNetMobileNetV2.h b/mv_machine_learning/gaze_tracking/include/L2CSNetMobileNetV2.h
new file mode 100644 (file)
index 0000000..ff51c57
--- /dev/null
@@ -0,0 +1,56 @@
+/**
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __L2CSNET_MOBILENET_V2_H__
+#define __L2CSNET_MOBILENET_V2_H__
+
+#include "mv_private.h"
+#include <memory>
+#include <mv_common.h>
+#include <string>
+
+#include "GazeTracking.h"
+#include <mv_inference_type.h>
+
+namespace mediavision
+{
+namespace machine_learning
+{
+/**
+ * @brief Gaze estimator based on L2CS-Net with a MobileNetV2 backbone.
+ *
+ * Decodes the model's two classification outputs (pitch and yaw bin logits)
+ * into continuous angles via a softmax-weighted expectation over the bins.
+ *
+ * NOTE(review): the class name uses a lowercase 'v' (L2CSNetMobileNetv2)
+ * while the file is named L2CSNetMobileNetV2.h — consider aligning the casing.
+ * NOTE(review): cv::Mat is used here but no OpenCV header is included
+ * directly; presumably it arrives transitively via GazeTracking.h — confirm.
+ */
+template<typename T> class L2CSNetMobileNetv2 : public GazeTracking<T>
+{
+       using GazeTracking<T>::_config;
+       using GazeTracking<T>::_preprocess;
+
+private:
+       GazeTrackingResult _result;
+
+       // Number of classification bins per angle head (model-specific).
+       int _bins = 28;
+       // Angular width of one bin in degrees; 28 bins * 3 deg = 84 deg span.
+       float _binWidth = 3.f;
+       // Offset (half the span) so the decoded angle is centered at 0 (+/-42 deg).
+       float _offsetAngle = 42.f;
+       // Column vector [0, 1, ..., _bins-1] reused for the softmax expectation.
+       cv::Mat _indexValues;
+
+       // Converts one head's raw bin logits into a single angle in radians.
+       float calculateAngle(std::vector<float>& values);
+
+public:
+       L2CSNetMobileNetv2(GazeTrackingTaskType task_type, std::shared_ptr<Config> config);
+       ~L2CSNetMobileNetv2();
+       GazeTrackingResult &result() override;
+};
+
+} // machine_learning
+} // mediavision
+
+#endif
\ No newline at end of file
index 30b442346d1ce5d5b64d8995a8aa89b46f274dc0..89ea2318aa6729b9333433a87f138dce43687bbc 100644 (file)
@@ -50,7 +50,8 @@ enum class GazeTrackingTaskType {
        GZE_TRACKING_TASK_NONE = 0,
        GZE_L2CS_NET,
        GZE_GENERAL_GAZE_NET,
-       GZE_TINY_TRACKER
+       GZE_TINY_TRACKER,
+       GZE_L2CS_NET_MOBILENETV2
        // TODO
 };
 
index 4b553ce768101a100742bc0f3252f8efeea2d71d..543f462b8360886085bfd4f6b734c04ac7a88c9b 100644 (file)
@@ -19,6 +19,7 @@
 #include "L2CSNet.h"
 #include "MvMlException.h"
 #include "TinyTracker.h"
+#include "L2CSNetMobileNetV2.h"
 #include "gaze_tracking_type.h"
 #include "mv_gaze_tracking_config.h"
 
@@ -61,6 +62,9 @@ template<typename U> void GazeTrackingAdapter::create(GazeTrackingTaskType task_
        case GazeTrackingTaskType::GZE_TINY_TRACKER:
                _gaze_tracking = make_unique<TinyTracker<U> >(task_type, _config);
                break;
+       case GazeTrackingTaskType::GZE_L2CS_NET_MOBILENETV2:
+               _gaze_tracking = make_unique<L2CSNetMobileNetv2<U> >(task_type, _config);
+               break;
        default:
                throw InvalidOperation("Invalid gaze tracking task type.");
        }
@@ -102,6 +106,8 @@ GazeTrackingTaskType GazeTrackingAdapter::convertToTaskType(string model_name)
                return GazeTrackingTaskType::GZE_GENERAL_GAZE_NET;
        if (model_name == "GZE_TINY_TRACKER")
                return GazeTrackingTaskType::GZE_TINY_TRACKER;
+       if (model_name == "GZE_L2CS_NET_MOBILENETV2")
+               return GazeTrackingTaskType::GZE_L2CS_NET_MOBILENETV2;
        // TODO.
 
        throw InvalidParameter("Invalid gaze tracking model name.");
diff --git a/mv_machine_learning/gaze_tracking/src/L2CSNetMobileNetV2.cpp b/mv_machine_learning/gaze_tracking/src/L2CSNetMobileNetV2.cpp
new file mode 100644 (file)
index 0000000..b50dc98
--- /dev/null
@@ -0,0 +1,98 @@
+/**
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <cmath>
+#include <map>
+#include <string.h>
+
+#include "L2CSNetMobileNetV2.h"
+#include "MvMlException.h"
+#include "Postprocess.h"
+#include "mv_gaze_tracking_config.h"
+
+using namespace std;
+using namespace mediavision::inference;
+using namespace mediavision::machine_learning::exception;
+
+namespace mediavision
+{
+namespace machine_learning
+{
+/**
+ * @brief Constructs the tracker and precomputes the bin-index vector.
+ *
+ * _indexValues is a _bins x 1 CV_32F column holding 0..(_bins-1); it is
+ * reused by calculateAngle() as the value grid for the softmax expectation.
+ */
+template<typename T>
+L2CSNetMobileNetv2<T>::L2CSNetMobileNetv2(GazeTrackingTaskType task_type, std::shared_ptr<Config> config)
+               : GazeTracking<T>(task_type, config), _result()
+{
+       _indexValues = cv::Mat(cv::Size(1, _bins), CV_32F);
+       for (int i = 0; i < _bins; ++i) {
+               _indexValues.at<float>(i) = static_cast<float>(i); // 0 ~ 27
+       }
+}
+
+// Defaulted-equivalent destructor; no resources beyond members to release.
+template<typename T> L2CSNetMobileNetv2<T>::~L2CSNetMobileNetv2()
+{}
+
+/**
+ * @brief Decodes one head's bin logits into a continuous angle in radians.
+ *
+ * Applies a numerically stable softmax (max-subtracted before exp) to the
+ * logits, takes the expected bin index against _indexValues, maps it to
+ * degrees via _binWidth and _offsetAngle, then converts to radians.
+ *
+ * NOTE(review): assumes values.size() == _bins — the element-wise mul with
+ * _indexValues would throw otherwise; consider validating the size.
+ *
+ * @param values raw model logits for one angle head; taken non-const because
+ *               cv::Mat wraps the buffer through values.data().
+ * @return angle in radians, nominally within +/-42 degrees.
+ */
+template<typename T> float L2CSNetMobileNetv2<T>::calculateAngle(vector<float>& values)
+{
+       // Wrap the logits without copying (column vector, height = values.size()).
+       cv::Mat valuesMat = cv::Mat(cv::Size(1, values.size()), CV_32F, values.data());
+
+       // Subtracting the max before exp() prevents overflow in the softmax.
+       double maxProb = 0.0;
+       cv::minMaxLoc(valuesMat, nullptr, &maxProb, nullptr, nullptr);
+
+       cv::Mat softmaxProb;
+       cv::exp(valuesMat-static_cast<float>(maxProb), softmaxProb);
+
+       float sum = static_cast<float>(cv::sum(softmaxProb)[0]);
+       softmaxProb /= sum;
+
+       // Expected bin index -> degrees -> radians.
+       cv::Mat calculatedValues = softmaxProb.mul(_indexValues);
+       return (static_cast<float>(cv::sum(calculatedValues)[0]) * _binWidth - _offsetAngle) * (M_PI / 180.f);
+}
+
+/**
+ * @brief Builds and returns the gaze result from the latest output tensors.
+ *
+ * Reads the two output tensors (pitch logits, yaw logits — in output-name
+ * order), decodes each into an angle, and fills a fresh GazeTrackingResult
+ * for a single face.
+ *
+ * NOTE(review): _result is value-reset below, so frame_number++ always
+ * yields 1 — it never counts across calls. If frame_number is meant to be a
+ * running counter, save it before the reset and restore it afterwards.
+ * NOTE(review): assumes the model exposes exactly two outputs with pitch
+ * first and yaw second — confirm against the model's metadata/json.
+ */
+template<typename T> GazeTrackingResult &L2CSNetMobileNetv2<T>::result()
+{
+       // result() may be called repeatedly, so discard any previously
+       // accumulated data before decoding the current tensors.
+       _result = GazeTrackingResult();
+
+       vector<string> names;
+
+       GazeTracking<T>::getOutputNames(names);
+
+       vector<float> pitches;
+       vector<float> yaws;
+
+       GazeTracking<T>::getOutputTensor(names[0], pitches);
+       GazeTracking<T>::getOutputTensor(names[1], yaws);
+
+       float pitch = calculateAngle(pitches);
+       float yaw = calculateAngle(yaws);
+
+       LOGD("L2CSNetMobileNetv2: yaw: %f, pitch: %f", yaw, pitch);
+
+       _result.frame_number++;
+       _result.number_of_faces = 1;
+       _result.yaws.push_back(yaw);
+       _result.pitches.push_back(pitch);
+
+       return _result;
+}
+
+template class L2CSNetMobileNetv2<float>;
+template class L2CSNetMobileNetv2<unsigned char>;
+template class L2CSNetMobileNetv2<char>;
+}
+}
index e0f503eef723c659e15734eb1f394035c022dd2b..594c111d048704a5144d8abdef2cb59c9f603366 100644 (file)
@@ -327,7 +327,7 @@ int mv_gaze_tracking_get_pos(mv_gaze_tracking_h handle, unsigned int index, floa
        MEDIA_VISION_INSTANCE_CHECK(y);
 
        MEDIA_VISION_FUNCTION_ENTER();
-
+       // NOTE: L2CSNetMobileNetV2 doesn't provide x_pos and y_pos results.
        try {
                auto &result = static_cast<GazeTrackingResult &>(machine_learning_native_get_result_cache(handle, TASK_NAME));
                if (index >= result.number_of_faces) {
index c0815d231a32a2b43fcb7ca6ce3a81934f95eeda..4073007fc78f6908d1fea75686a905c327e9d7e8 100644 (file)
@@ -26,6 +26,7 @@
 #include "mv_gaze_tracking_internal.h"
 
 #define IMG_GAZE TEST_RES_PATH "/res/inference/images/gazeDetection.jpg"
+#define IMG_GAZE_FACE TEST_RES_PATH "/res/inference/images/gazeTracking.jpg"
 
 using namespace testing;
 using namespace std;
@@ -136,3 +137,70 @@ TEST(GazeTrackingTest, InferenceShouldBeOk)
        ret = mv_destroy_source(mv_source);
        ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
 }
+
+// End-to-end inference test for the L2CS-Net/MobileNetV2 gaze model:
+// load a face image, run the full create/configure/prepare/infer cycle,
+// and check the decoded yaw/pitch against precomputed reference values.
+TEST(GazeTrackingTest, L2CSNetMobileNetV2InferenceShouldBeOk)
+{
+       mv_gaze_tracking_h handle;
+       vector<test_model_input> test_models {
+               { "gzt_l2cs_mobilenetv2_224x224.tflite", "gzt_l2cs_mobilenetv2_224x224.json", "", "GZE_L2CS_NET_MOBILENETV2" }
+       };
+
+       float eps = 1e-4f;
+       // NOTE(review): a single yaw/pitch pair is shared by all models in the
+       // loop; answer_idx below is incremented but never used to index answers,
+       // so adding a second model would silently compare against the wrong
+       // reference. Make answers per-model (e.g. answers[answer_idx]) or drop
+       // answer_idx.
+       vector<float> answers { 0.0475, -0.2128 }; // yaw, pitch
+
+       mv_source_h mv_source = NULL;
+       int ret = mv_create_source(&mv_source);
+       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+       ret = ImageHelper::loadImageToSource(IMG_GAZE_FACE, mv_source);
+       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+       unsigned int answer_idx = 0;
+
+       for (const auto &model : test_models) {
+               ret = mv_gaze_tracking_create(&handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_set_model(handle, model.model_file.c_str(), model.meta_file.c_str(),
+                                                                                model.label_file.c_str(), model.model_name.c_str());
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_set_engine(handle, "tflite", "cpu");
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_configure(handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_prepare(handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_gaze_tracking_inference(handle, mv_source);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               unsigned long frame_number;
+               unsigned int number_of_faces;
+
+               ret = mv_gaze_tracking_get_result_count(handle, &frame_number, &number_of_faces);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               for (unsigned int idx = 0; idx < number_of_faces; ++idx) {
+                       float yaw, pitch;
+
+                       ret = mv_gaze_tracking_get_raw_data(handle, idx, &yaw, &pitch);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+                       // The GE/LE pairs implement |value - answer| <= eps;
+                       // ASSERT_NEAR(yaw, answers[0], eps) would state this directly.
+                       ASSERT_GE(answers[0]+eps, yaw);
+                       ASSERT_LE(answers[0]-eps, yaw);
+                       ASSERT_GE(answers[1]+eps, pitch);
+                       ASSERT_LE(answers[1]-eps, pitch);
+               }
+
+               answer_idx++;
+
+               ret = mv_gaze_tracking_destroy(handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+       }
+
+       ret = mv_destroy_source(mv_source);
+       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+}
\ No newline at end of file