--- /dev/null
+/**
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __L2CSNET_MOBILENET_V2_H__
+#define __L2CSNET_MOBILENET_V2_H__
+
+#include "mv_private.h"
+#include <memory>
+#include <mv_common.h>
+#include <string>
+
+#include "GazeTracking.h"
+#include <mv_inference_type.h>
+
+namespace mediavision
+{
+namespace machine_learning
+{
+/**
+ * @brief Gaze tracking task based on L2CS-Net with a MobileNetV2 backbone.
+ *
+ * The model emits per-bin classification logits for yaw and pitch; result()
+ * converts them back into continuous angles (radians) via a softmax
+ * expectation over the bin indices.
+ *
+ * @tparam T tensor element type used by the inference backend
+ *           (float / unsigned char / char are instantiated).
+ */
+template<typename T> class L2CSNetMobileNetv2 : public GazeTracking<T>
+{
+	using GazeTracking<T>::_config;
+	using GazeTracking<T>::_preprocess;
+
+private:
+	GazeTrackingResult _result; // cached result handed out by result()
+
+	// Bin layout of the classification head:
+	//   angle_deg = expected_bin_index * _binWidth - _offsetAngle
+	// NOTE(review): constants assume a 28-bin head — confirm against the
+	// model meta file if the model is retrained.
+	int _bins = 28;
+	float _binWidth = 3.f;
+	float _offsetAngle = 42.f;
+	cv::Mat _indexValues; // _bins x 1 float matrix [0 .. _bins-1] for the expectation
+
+	// Converts one output tensor (per-bin logits) into a continuous angle in radians.
+	float calculateAngle(std::vector<float>& values);
+
+public:
+	L2CSNetMobileNetv2(GazeTrackingTaskType task_type, std::shared_ptr<Config> config);
+	~L2CSNetMobileNetv2();
+	GazeTrackingResult &result() override;
+};
+
+} // machine_learning
+} // mediavision
+
+#endif
\ No newline at end of file
GZE_TRACKING_TASK_NONE = 0,
GZE_L2CS_NET,
GZE_GENERAL_GAZE_NET,
- GZE_TINY_TRACKER
+ GZE_TINY_TRACKER,
+ GZE_L2CS_NET_MOBILENETV2
// TODO
};
#include "L2CSNet.h"
#include "MvMlException.h"
#include "TinyTracker.h"
+#include "L2CSNetMobileNetV2.h"
#include "gaze_tracking_type.h"
#include "mv_gaze_tracking_config.h"
case GazeTrackingTaskType::GZE_TINY_TRACKER:
_gaze_tracking = make_unique<TinyTracker<U> >(task_type, _config);
break;
+ case GazeTrackingTaskType::GZE_L2CS_NET_MOBILENETV2:
+ _gaze_tracking = make_unique<L2CSNetMobileNetv2<U> >(task_type, _config);
+ break;
default:
throw InvalidOperation("Invalid gaze tracking task type.");
}
return GazeTrackingTaskType::GZE_GENERAL_GAZE_NET;
if (model_name == "GZE_TINY_TRACKER")
return GazeTrackingTaskType::GZE_TINY_TRACKER;
+ if (model_name == "GZE_L2CS_NET_MOBILENETV2")
+ return GazeTrackingTaskType::GZE_L2CS_NET_MOBILENETV2;
// TODO.
throw InvalidParameter("Invalid gaze tracking model name.");
--- /dev/null
+/**
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <cmath>
+#include <map>
+#include <string.h>
+
+#include "L2CSNetMobileNetV2.h"
+#include "MvMlException.h"
+#include "Postprocess.h"
+#include "mv_gaze_tracking_config.h"
+
+using namespace std;
+using namespace mediavision::inference;
+using namespace mediavision::machine_learning::exception;
+
+namespace mediavision
+{
+namespace machine_learning
+{
+/**
+ * @brief Constructs the task and precomputes the bin-index vector.
+ *
+ * @param task_type gaze tracking task type this instance serves.
+ * @param config    model/engine configuration forwarded to the base class.
+ */
+template<typename T>
+L2CSNetMobileNetv2<T>::L2CSNetMobileNetv2(GazeTrackingTaskType task_type, std::shared_ptr<Config> config)
+		: GazeTracking<T>(task_type, config), _result()
+{
+	// _bins x 1 float matrix holding [0, 1, ..., _bins - 1]; calculateAngle()
+	// multiplies it element-wise with the softmax probabilities to get the
+	// expected bin index.
+	_indexValues = cv::Mat(cv::Size(1, _bins), CV_32F);
+	for (int i = 0; i < _bins; ++i) {
+		_indexValues.at<float>(i) = static_cast<float>(i); // 0 ~ 27
+	}
+}
+
+// Defaulted: every member releases its resources through its own destructor.
+template<typename T> L2CSNetMobileNetv2<T>::~L2CSNetMobileNetv2() = default;
+
+/**
+ * @brief Converts per-bin classification logits into a continuous angle.
+ *
+ * Applies a numerically stable softmax to the logits, computes the expected
+ * bin index, and maps it to an angle: (index * _binWidth - _offsetAngle)
+ * degrees, returned in radians.
+ *
+ * @param values per-bin logits from one output tensor; must contain exactly
+ *               _bins elements.
+ * @return angle in radians.
+ * @throws InvalidOperation if the tensor size does not match the bin count.
+ */
+template<typename T> float L2CSNetMobileNetv2<T>::calculateAngle(vector<float>& values)
+{
+	// mul() below requires the same dimensions as _indexValues, so fail with
+	// a clean error instead of an OpenCV assertion on a mismatched model head.
+	if (values.size() != static_cast<size_t>(_bins))
+		throw InvalidOperation("Invalid number of output bins.");
+
+	cv::Mat valuesMat = cv::Mat(cv::Size(1, static_cast<int>(values.size())), CV_32F, values.data());
+
+	// Numerically stable softmax: subtract the max logit before exponentiation.
+	double maxProb = 0.0;
+	cv::minMaxLoc(valuesMat, nullptr, &maxProb, nullptr, nullptr);
+
+	cv::Mat softmaxProb;
+	cv::exp(valuesMat - static_cast<float>(maxProb), softmaxProb);
+
+	float sum = static_cast<float>(cv::sum(softmaxProb)[0]);
+	softmaxProb /= sum;
+
+	// Expected bin index under the softmax -> degrees -> radians.
+	cv::Mat calculatedValues = softmaxProb.mul(_indexValues);
+	return (static_cast<float>(cv::sum(calculatedValues)[0]) * _binWidth - _offsetAngle) * (M_PI / 180.f);
+}
+
+/**
+ * @brief Returns the gaze tracking result decoded from the model outputs.
+ *
+ * Reads the two output tensors (pitch logits, then yaw logits), decodes each
+ * into a continuous angle and stores them in _result.
+ *
+ * @return reference to the internal result object, valid until the next call.
+ * @throws InvalidOperation if the model does not expose two output tensors.
+ */
+template<typename T> GazeTrackingResult &L2CSNetMobileNetv2<T>::result()
+{
+	// Clear _result object because result() function can be called every time user wants
+	// so make sure to clear existing result data before getting the data again.
+	_result = GazeTrackingResult();
+
+	vector<string> names;
+
+	GazeTracking<T>::getOutputNames(names);
+
+	// L2CS-Net provides two output tensors: pitch logits and yaw logits.
+	// Guard the indexed accesses below against a mismatched model.
+	if (names.size() < 2)
+		throw InvalidOperation("Invalid number of output tensors.");
+
+	vector<float> pitches;
+	vector<float> yaws;
+
+	GazeTracking<T>::getOutputTensor(names[0], pitches);
+	GazeTracking<T>::getOutputTensor(names[1], yaws);
+
+	float pitch = calculateAngle(pitches);
+	float yaw = calculateAngle(yaws);
+
+	LOGD("L2CSNetMobileNetv2: yaw: %f, pitch: %f", yaw, pitch);
+
+	_result.frame_number++;
+	_result.number_of_faces = 1;
+	_result.yaws.push_back(yaw);
+	_result.pitches.push_back(pitch);
+
+	return _result;
+}
+
+// Explicit instantiations for the tensor element types supported by the
+// inference backends; keeps the template definitions in this .cpp file.
+template class L2CSNetMobileNetv2<float>;
+template class L2CSNetMobileNetv2<unsigned char>;
+template class L2CSNetMobileNetv2<char>;
+} // namespace machine_learning
+} // namespace mediavision
MEDIA_VISION_INSTANCE_CHECK(y);
MEDIA_VISION_FUNCTION_ENTER();
-
+ // NOTE: L2CSNetMobileNetV2 doesn't provide x_pos and y_pos results.
try {
auto &result = static_cast<GazeTrackingResult &>(machine_learning_native_get_result_cache(handle, TASK_NAME));
if (index >= result.number_of_faces) {
#include "mv_gaze_tracking_internal.h"
#define IMG_GAZE TEST_RES_PATH "/res/inference/images/gazeDetection.jpg"
+#define IMG_GAZE_FACE TEST_RES_PATH "/res/inference/images/gazeTracking.jpg"
using namespace testing;
using namespace std;
ret = mv_destroy_source(mv_source);
ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
}
+
+/**
+ * End-to-end inference test for the L2CS-Net MobileNetV2 gaze model:
+ * create -> set model -> configure -> prepare -> inference -> verify the
+ * decoded yaw/pitch against reference values within a small tolerance.
+ */
+TEST(GazeTrackingTest, L2CSNetMobileNetV2InferenceShouldBeOk)
+{
+	mv_gaze_tracking_h handle;
+	vector<test_model_input> test_models {
+		{ "gzt_l2cs_mobilenetv2_224x224.tflite", "gzt_l2cs_mobilenetv2_224x224.json", "", "GZE_L2CS_NET_MOBILENETV2" }
+	};
+
+	float eps = 1e-4f;
+	vector<float> answers { 0.0475, -0.2128 }; // yaw, pitch
+
+	mv_source_h mv_source = NULL;
+	int ret = mv_create_source(&mv_source);
+	ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+	ret = ImageHelper::loadImageToSource(IMG_GAZE_FACE, mv_source);
+	ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+	for (const auto &model : test_models) {
+		ret = mv_gaze_tracking_create(&handle);
+		ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+		ret = mv_gaze_tracking_set_model(handle, model.model_file.c_str(), model.meta_file.c_str(),
+										 model.label_file.c_str(), model.model_name.c_str());
+		ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+		ret = mv_gaze_tracking_set_engine(handle, "tflite", "cpu");
+		ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+		ret = mv_gaze_tracking_configure(handle);
+		ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+		ret = mv_gaze_tracking_prepare(handle);
+		ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+		ret = mv_gaze_tracking_inference(handle, mv_source);
+		ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+		unsigned long frame_number;
+		unsigned int number_of_faces;
+
+		ret = mv_gaze_tracking_get_result_count(handle, &frame_number, &number_of_faces);
+		ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+		for (unsigned int idx = 0; idx < number_of_faces; ++idx) {
+			float yaw, pitch;
+
+			ret = mv_gaze_tracking_get_raw_data(handle, idx, &yaw, &pitch);
+			ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+			// |value - answer| <= eps, expressed directly instead of GE/LE pairs.
+			ASSERT_NEAR(answers[0], yaw, eps);
+			ASSERT_NEAR(answers[1], pitch, eps);
+		}
+
+		ret = mv_gaze_tracking_destroy(handle);
+		ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+	}
+
+	ret = mv_destroy_source(mv_source);
+	ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+}
\ No newline at end of file