--- /dev/null
+/**
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __G_GAZE_ESTIMATION_H__
+#define __G_GAZE_ESTIMATION_H__
+
+#include "mv_private.h"
+#include <memory>
+#include <mv_common.h>
+#include <string>
+
+#include "GazeTracking.h"
+#include <mv_inference_type.h>
+
+namespace mediavision
+{
+namespace machine_learning
+{
+template<typename T> class GGazeEstimation : public GazeTracking<T>
+{
+ using GazeTracking<T>::_config;
+ using GazeTracking<T>::_preprocess;
+
+private:
+ GazeTrackingResult _result;
+
+public:
+ GGazeEstimation(GazeTrackingTaskType task_type, std::shared_ptr<Config> config);
+ ~GGazeEstimation();
+
+ GazeTrackingResult &result() override;
+};
+
+} // machine_learning
+} // mediavision
+
+#endif
\ No newline at end of file
#include "EngineConfig.h"
#include "ITask.h"
-#include "L2CSNet.h"
#include "MvMlConfig.h"
-
+#include "IGazeTracking.h"
namespace mediavision
{
namespace machine_learning
enum class GazeTrackingTaskType {
GAZE_TRACKINGION_TASK_NONE = 0,
L2CS_NET,
+ G_GAZE_ESTIMATION
// TODO
};
{
"name" : "MODEL_FILE_NAME",
"type" : "string",
- "value" : "l2cs_net_1x3x448x448_float32.tflite"
+ "value" : "generalizing_gaze_estimation_with_weak_supervision_from_synthetic_views_160x160_float16.tflite"
},
{
"name" : "DEFAULT_MODEL_NAME",
"type" : "string",
- "value" : "L2CS_NET"
+ "value" : "G_GAZE_ESTIMATION"
},
{
"name" : "MODEL_META_FILE_NAME",
"type" : "string",
- "value" : "l2cs_net_1x3x448x448_float32.json"
+ "value" : "generalizing_gaze_estimation_with_weak_supervision_from_synthetic_views_160x160_float16.json"
},
{
"name" : "BACKEND_TYPE",
{
"name" : "DEFAULT_MODEL_NAME",
"type" : "string",
- "value" : "L2CS_NET"
+ "value" : "G_GAZE_ESTIMATION"
},
{
"name" : "USE_PLUGIN",
--- /dev/null
+/**
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <cmath>
+#include <map>
+#include <string.h>
+
+#include "GGazeEstimation.h"
+#include "MvMlException.h"
+#include "Postprocess.h"
+#include "mv_gaze_tracking_config.h"
+
+using namespace std;
+using namespace mediavision::inference;
+using namespace mediavision::machine_learning::exception;
+
+namespace mediavision
+{
+namespace machine_learning
+{
+template<typename T>
+GGazeEstimation<T>::GGazeEstimation(GazeTrackingTaskType task_type, std::shared_ptr<Config> config)
+ : GazeTracking<T>(task_type, config), _result()
+{}
+
+template<typename T> GGazeEstimation<T>::~GGazeEstimation()
+{}
+
+template<typename T> GazeTrackingResult &GGazeEstimation<T>::result()
+{
+ // Clear _result object because result() function can be called every time user wants
+ // so make sure to clear existing result data before getting the data again.
+ _result = GazeTrackingResult();
+
+ vector<string> names;
+
+ GazeTracking<T>::getOutputNames(names);
+
+ vector<float> outputTensor;
+
+ GazeTracking<T>::getOutputTensor(names[0], outputTensor);
+
+ LOGD("GGazeEstimation::result() - outputTensor size: %zu", outputTensor.size());
+
+ _result.frame_number++;
+ return _result;
+}
+
+template class GGazeEstimation<float>;
+template class GGazeEstimation<unsigned char>;
+
+}
+}
#include "MvMlException.h"
#include "gaze_tracking_type.h"
#include "mv_gaze_tracking_config.h"
+#include "L2CSNet.h"
+#include "GGazeEstimation.h"
using namespace std;
using namespace MediaVision::Common;
case GazeTrackingTaskType::L2CS_NET:
_gaze_tracking = make_unique<L2CSNet<U> >(task_type, _config);
break;
+ case GazeTrackingTaskType::G_GAZE_ESTIMATION:
+ _gaze_tracking = make_unique<GGazeEstimation<U> >(task_type, _config);
+ break;
default:
throw InvalidOperation("Invalid gaze tracking task type.");
}
if (model_name == "L2CS_NET")
return GazeTrackingTaskType::L2CS_NET;
+ if (model_name == "G_GAZE_ESTIMATION")
+ return GazeTrackingTaskType::G_GAZE_ESTIMATION;
// TODO.
throw InvalidParameter("Invalid gaze tracking model name.");
#define TASK_NAME "gaze_tracking"
using namespace std;
-using namespace mediavision::inference;
using namespace mediavision::common;
using namespace mediavision::machine_learning;
using namespace MediaVision::Common;