mv_machine_learning: add G_GAZE_ESTIMATION model support
authorInki Dae <inki.dae@samsung.com>
Tue, 24 Sep 2024 06:35:33 +0000 (15:35 +0900)
committerInki Dae <inki.dae@samsung.com>
Thu, 7 Nov 2024 05:26:11 +0000 (14:26 +0900)
Change-Id: I6202e243f0542911a952a7598b1eb76ee2a2f22d
Signed-off-by: Inki Dae <inki.dae@samsung.com>
mv_machine_learning/gaze_tracking/include/GGazeEstimation.h [new file with mode: 0644]
mv_machine_learning/gaze_tracking/include/GazeTrackingAdapter.h
mv_machine_learning/gaze_tracking/include/gaze_tracking_type.h
mv_machine_learning/gaze_tracking/meta/gaze_tracking.json
mv_machine_learning/gaze_tracking/meta/gaze_tracking_plugin.json
mv_machine_learning/gaze_tracking/src/GGazeEstimation.cpp [new file with mode: 0644]
mv_machine_learning/gaze_tracking/src/GazeTrackingAdapter.cpp
mv_machine_learning/gaze_tracking/src/mv_gaze_tracking.cpp

diff --git a/mv_machine_learning/gaze_tracking/include/GGazeEstimation.h b/mv_machine_learning/gaze_tracking/include/GGazeEstimation.h
new file mode 100644 (file)
index 0000000..9cbfa47
--- /dev/null
@@ -0,0 +1,50 @@
+/**
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __G_GAZE_ESTIMATION_H__
+#define __G_GAZE_ESTIMATION_H__
+
+#include "mv_private.h"
+#include <memory>
+#include <mv_common.h>
+#include <string>
+
+#include "GazeTracking.h"
+#include <mv_inference_type.h>
+
+namespace mediavision
+{
+namespace machine_learning
+{
+template<typename T> class GGazeEstimation : public GazeTracking<T>
+{
+       using GazeTracking<T>::_config;
+       using GazeTracking<T>::_preprocess;
+
+private:
+       GazeTrackingResult _result;
+
+public:
+       GGazeEstimation(GazeTrackingTaskType task_type, std::shared_ptr<Config> config);
+       ~GGazeEstimation();
+
+       GazeTrackingResult &result() override;
+};
+
+} // machine_learning
+} // mediavision
+
+#endif
\ No newline at end of file
index 05389a5f064c2decb09413d02c6a4d1985007bfc..136d47414f35e08482a090568ed9071eba33b00c 100644 (file)
@@ -21,9 +21,8 @@
 
 #include "EngineConfig.h"
 #include "ITask.h"
-#include "L2CSNet.h"
 #include "MvMlConfig.h"
-
+#include "IGazeTracking.h"
 namespace mediavision
 {
 namespace machine_learning
index 020f479808f193377546b406da0df06f437a4aa5..ea462f12e5a9365697527aaa8bd07b49945b8b9a 100644 (file)
@@ -49,6 +49,7 @@ struct GazeTrackingResult : public OutputBaseType {
 enum class GazeTrackingTaskType {
        GAZE_TRACKINGION_TASK_NONE = 0,
        L2CS_NET,
+       G_GAZE_ESTIMATION
        // TODO
 };
 
index aa81ae0da772b76889fba469e008fc5971c91dc3..d42ec2fe1617ba608e20d0de99ce9cfd803e0891 100644 (file)
@@ -9,17 +9,17 @@
                {
             "name"  : "MODEL_FILE_NAME",
             "type"  : "string",
-            "value" : "l2cs_net_1x3x448x448_float32.tflite"
+            "value" : "generalizing_gaze_estimation_with_weak_supervision_from_synthetic_views_160x160_float16.tflite"
         },
         {
             "name"  : "DEFAULT_MODEL_NAME",
             "type"  : "string",
-            "value" : "L2CS_NET"
+            "value" : "G_GAZE_ESTIMATION"
         },
         {
             "name"  : "MODEL_META_FILE_NAME",
             "type"  : "string",
-            "value" : "l2cs_net_1x3x448x448_float32.json"
+            "value" : "generalizing_gaze_estimation_with_weak_supervision_from_synthetic_views_160x160_float16.json"
         },
         {
             "name"  : "BACKEND_TYPE",
index bc56eee6d97cc3e3a2a696e3d69cacda87de9c97..2b113ad3a75d0df86ec14f32307c9dbe375c88af 100644 (file)
@@ -9,7 +9,7 @@
         {
             "name"  : "DEFAULT_MODEL_NAME",
             "type"  : "string",
-            "value" : "L2CS_NET"
+            "value" : "G_GAZE_ESTIMATION"
         },
         {
             "name"  : "USE_PLUGIN",
diff --git a/mv_machine_learning/gaze_tracking/src/GGazeEstimation.cpp b/mv_machine_learning/gaze_tracking/src/GGazeEstimation.cpp
new file mode 100644 (file)
index 0000000..3b9a878
--- /dev/null
@@ -0,0 +1,67 @@
+/**
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <cmath>
+#include <map>
+#include <string.h>
+
+#include "GGazeEstimation.h"
+#include "MvMlException.h"
+#include "Postprocess.h"
+#include "mv_gaze_tracking_config.h"
+
+using namespace std;
+using namespace mediavision::inference;
+using namespace mediavision::machine_learning::exception;
+
+namespace mediavision
+{
+namespace machine_learning
+{
+template<typename T>
+GGazeEstimation<T>::GGazeEstimation(GazeTrackingTaskType task_type, std::shared_ptr<Config> config)
+               : GazeTracking<T>(task_type, config), _result()
+{}
+
+template<typename T> GGazeEstimation<T>::~GGazeEstimation()
+{}
+
+template<typename T> GazeTrackingResult &GGazeEstimation<T>::result()
+{
+       // Clear _result object because result() function can be called every time user wants
+       // so make sure to clear existing result data before getting the data again.
+       _result = GazeTrackingResult();
+
+       vector<string> names;
+
+       GazeTracking<T>::getOutputNames(names);
+
+       vector<float> outputTensor;
+
+       GazeTracking<T>::getOutputTensor(names[0], outputTensor);
+
+       LOGD("GGazeEstimation::result() - outputTensor size: %zu", outputTensor.size());
+
+       _result.frame_number++;
+       return _result;
+}
+
+template class GGazeEstimation<float>;
+template class GGazeEstimation<unsigned char>;
+
+}
+}
index 6c95f5d6a1f912431ffcccfd1211b5fbee44d363..775a6ab81d56dcce4fab1fb87afa1724aa9de32d 100644 (file)
@@ -18,6 +18,8 @@
 #include "MvMlException.h"
 #include "gaze_tracking_type.h"
 #include "mv_gaze_tracking_config.h"
+#include "L2CSNet.h"
+#include "GGazeEstimation.h"
 
 using namespace std;
 using namespace MediaVision::Common;
@@ -52,6 +54,9 @@ template<typename U> void GazeTrackingAdapter::create(GazeTrackingTaskType task_
        case GazeTrackingTaskType::L2CS_NET:
                _gaze_tracking = make_unique<L2CSNet<U> >(task_type, _config);
                break;
+       case GazeTrackingTaskType::G_GAZE_ESTIMATION:
+               _gaze_tracking = make_unique<GGazeEstimation<U> >(task_type, _config);
+               break;
        default:
                throw InvalidOperation("Invalid gaze tracking task type.");
        }
@@ -86,6 +91,8 @@ GazeTrackingTaskType GazeTrackingAdapter::convertToTaskType(string model_name)
 
        if (model_name == "L2CS_NET")
                return GazeTrackingTaskType::L2CS_NET;
+       if (model_name == "G_GAZE_ESTIMATION")
+               return GazeTrackingTaskType::G_GAZE_ESTIMATION;
        // TODO.
 
        throw InvalidParameter("Invalid gaze tracking model name.");
index 804d62ead379417772a8273ed1c856d71d0ffd62..e0f503eef723c659e15734eb1f394035c022dd2b 100644 (file)
@@ -36,7 +36,6 @@
 #define TASK_NAME "gaze_tracking"
 
 using namespace std;
-using namespace mediavision::inference;
 using namespace mediavision::common;
 using namespace mediavision::machine_learning;
 using namespace MediaVision::Common;