mv_machine_learning: rename gaze model name and task type for readability and consistency
authorTae-Young Chung <ty83.chung@samsung.com>
Thu, 21 Nov 2024 05:27:10 +0000 (14:27 +0900)
committerKwanghoon Son <k.son@samsung.com>
Thu, 12 Dec 2024 01:56:41 +0000 (10:56 +0900)
Change-Id: I17da9351a49214f7a47e1623197ffa25cd2f83dd
Signed-off-by: Tae-Young Chung <ty83.chung@samsung.com>
mv_machine_learning/gaze_tracking/include/GGazeEstimation.h [deleted file]
mv_machine_learning/gaze_tracking/include/GazeTracking.h
mv_machine_learning/gaze_tracking/include/GeneralGazeNet.h [new file with mode: 0644]
mv_machine_learning/gaze_tracking/include/gaze_tracking_type.h
mv_machine_learning/gaze_tracking/meta/gaze_tracking.json
mv_machine_learning/gaze_tracking/meta/gaze_tracking_plugin.json
mv_machine_learning/gaze_tracking/src/GGazeEstimation.cpp [deleted file]
mv_machine_learning/gaze_tracking/src/GazeTrackingAdapter.cpp
mv_machine_learning/gaze_tracking/src/GeneralGazeNet.cpp [new file with mode: 0644]

diff --git a/mv_machine_learning/gaze_tracking/include/GGazeEstimation.h b/mv_machine_learning/gaze_tracking/include/GGazeEstimation.h
deleted file mode 100644 (file)
index 9cbfa47..0000000
+++ /dev/null
@@ -1,50 +0,0 @@
-/**
- * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __G_GAZE_ESTIMATION_H__
-#define __G_GAZE_ESTIMATION_H__
-
-#include "mv_private.h"
-#include <memory>
-#include <mv_common.h>
-#include <string>
-
-#include "GazeTracking.h"
-#include <mv_inference_type.h>
-
-namespace mediavision
-{
-namespace machine_learning
-{
-template<typename T> class GGazeEstimation : public GazeTracking<T>
-{
-       using GazeTracking<T>::_config;
-       using GazeTracking<T>::_preprocess;
-
-private:
-       GazeTrackingResult _result;
-
-public:
-       GGazeEstimation(GazeTrackingTaskType task_type, std::shared_ptr<Config> config);
-       ~GGazeEstimation();
-
-       GazeTrackingResult &result() override;
-};
-
-} // machine_learning
-} // mediavision
-
-#endif
\ No newline at end of file
index 5f6c3c1bc64bdb43339117437e4aeaa7040be575..ae5d66ebcf7211617e6e489b58dd1ef664d8b720 100644 (file)
@@ -44,7 +44,7 @@ namespace machine_learning
 template<typename T> class GazeTracking : public IGazeTracking
 {
 private:
-       GazeTrackingTaskType _task_type { GazeTrackingTaskType::GAZE_TRACKINGION_TASK_NONE };
+       GazeTrackingTaskType _task_type { GazeTrackingTaskType::GZE_TRACKING_TASK_NONE };
        std::unique_ptr<AsyncManager<T, GazeTrackingResult> > _async_manager;
        GazeTrackingResult _current_result;
 
diff --git a/mv_machine_learning/gaze_tracking/include/GeneralGazeNet.h b/mv_machine_learning/gaze_tracking/include/GeneralGazeNet.h
new file mode 100644 (file)
index 0000000..7db4186
--- /dev/null
@@ -0,0 +1,50 @@
+/**
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __GENERAL_GAZE_NET_H__
+#define __GENERAL_GAZE_NET_H__
+
+#include "mv_private.h"
+#include <memory>
+#include <mv_common.h>
+#include <string>
+
+#include "GazeTracking.h"
+#include <mv_inference_type.h>
+
+namespace mediavision
+{
+namespace machine_learning
+{
+template<typename T> class GeneralGazeNet : public GazeTracking<T>
+{
+       using GazeTracking<T>::_config;
+       using GazeTracking<T>::_preprocess;
+
+private:
+       GazeTrackingResult _result;
+
+public:
+       GeneralGazeNet(GazeTrackingTaskType task_type, std::shared_ptr<Config> config);
+       ~GeneralGazeNet();
+
+       GazeTrackingResult &result() override;
+};
+
+} // machine_learning
+} // mediavision
+
+#endif
\ No newline at end of file
index 0f26953a31e36f47019030a9804bad48a5b017d9..30b442346d1ce5d5b64d8995a8aa89b46f274dc0 100644 (file)
@@ -47,10 +47,10 @@ struct GazeTrackingResult : public OutputBaseType {
 };
 
 enum class GazeTrackingTaskType {
-       GAZE_TRACKINGION_TASK_NONE = 0,
-       L2CS_NET,
-       G_GAZE_ESTIMATION,
-       TINY_TRACKER
+       GZE_TRACKING_TASK_NONE = 0,
+       GZE_L2CS_NET,
+       GZE_GENERAL_GAZE_NET,
+       GZE_TINY_TRACKER
        // TODO
 };
 
index 4236faf59f8e951620c276ebff8d6e27a4ce6558..bf152625b5f3e6589663537cfeda7801e833399b 100644 (file)
@@ -14,7 +14,7 @@
         {
             "name"  : "DEFAULT_MODEL_NAME",
             "type"  : "string",
-            "value" : "TINY_TRACKER"
+            "value" : "GZE_TINY_TRACKER"
         },
         {
             "name"  : "MODEL_META_FILE_NAME",
index 2b113ad3a75d0df86ec14f32307c9dbe375c88af..c803c90a3d1e4e01265ad42ba54cfe65027323f5 100644 (file)
@@ -9,7 +9,7 @@
         {
             "name"  : "DEFAULT_MODEL_NAME",
             "type"  : "string",
-            "value" : "G_GAZE_ESTIMATION"
+            "value" : "GZE_GENERAL_GAZE_NET"
         },
         {
             "name"  : "USE_PLUGIN",
diff --git a/mv_machine_learning/gaze_tracking/src/GGazeEstimation.cpp b/mv_machine_learning/gaze_tracking/src/GGazeEstimation.cpp
deleted file mode 100644 (file)
index 94a13aa..0000000
+++ /dev/null
@@ -1,68 +0,0 @@
-/**
- * Copyright (c) 2024 Samsung Electronics Co., Ltd All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <algorithm>
-#include <cmath>
-#include <map>
-#include <string.h>
-
-#include "GGazeEstimation.h"
-#include "MvMlException.h"
-#include "Postprocess.h"
-#include "mv_gaze_tracking_config.h"
-
-using namespace std;
-using namespace mediavision::inference;
-using namespace mediavision::machine_learning::exception;
-
-namespace mediavision
-{
-namespace machine_learning
-{
-template<typename T>
-GGazeEstimation<T>::GGazeEstimation(GazeTrackingTaskType task_type, std::shared_ptr<Config> config)
-               : GazeTracking<T>(task_type, config), _result()
-{}
-
-template<typename T> GGazeEstimation<T>::~GGazeEstimation()
-{}
-
-template<typename T> GazeTrackingResult &GGazeEstimation<T>::result()
-{
-       // Clear _result object because result() function can be called every time user wants
-       // so make sure to clear existing result data before getting the data again.
-       _result = GazeTrackingResult();
-
-       vector<string> names;
-
-       GazeTracking<T>::getOutputNames(names);
-
-       vector<float> outputTensor;
-
-       GazeTracking<T>::getOutputTensor(names[0], outputTensor);
-
-       LOGD("GGazeEstimation::result() - outputTensor size: %zu", outputTensor.size());
-
-       _result.frame_number++;
-       return _result;
-}
-
-template class GGazeEstimation<float>;
-template class GGazeEstimation<unsigned char>;
-template class GGazeEstimation<char>;
-
-}
-}
index 89c64e2c3465b942e085eb333e06f61084f2726e..4b553ce768101a100742bc0f3252f8efeea2d71d 100644 (file)
@@ -15,7 +15,7 @@
  */
 
 #include "GazeTrackingAdapter.h"
-#include "GGazeEstimation.h"
+#include "GeneralGazeNet.h"
 #include "L2CSNet.h"
 #include "MvMlException.h"
 #include "TinyTracker.h"
@@ -52,13 +52,13 @@ GazeTrackingAdapter::~GazeTrackingAdapter()
 template<typename U> void GazeTrackingAdapter::create(GazeTrackingTaskType task_type)
 {
        switch (task_type) {
-       case GazeTrackingTaskType::L2CS_NET:
+       case GazeTrackingTaskType::GZE_L2CS_NET:
                _gaze_tracking = make_unique<L2CSNet<U> >(task_type, _config);
                break;
-       case GazeTrackingTaskType::G_GAZE_ESTIMATION:
-               _gaze_tracking = make_unique<GGazeEstimation<U> >(task_type, _config);
+       case GazeTrackingTaskType::GZE_GENERAL_GAZE_NET:
+               _gaze_tracking = make_unique<GeneralGazeNet<U> >(task_type, _config);
                break;
-       case GazeTrackingTaskType::TINY_TRACKER:
+       case GazeTrackingTaskType::GZE_TINY_TRACKER:
                _gaze_tracking = make_unique<TinyTracker<U> >(task_type, _config);
                break;
        default:
@@ -96,12 +96,12 @@ GazeTrackingTaskType GazeTrackingAdapter::convertToTaskType(string model_name)
 
        transform(model_name.begin(), model_name.end(), model_name.begin(), ::toupper);
 
-       if (model_name == "L2CS_NET")
-               return GazeTrackingTaskType::L2CS_NET;
-       if (model_name == "G_GAZE_ESTIMATION")
-               return GazeTrackingTaskType::G_GAZE_ESTIMATION;
-       if (model_name == "TINY_TRACKER")
-               return GazeTrackingTaskType::TINY_TRACKER;
+       if (model_name == "GZE_L2CS_NET")
+               return GazeTrackingTaskType::GZE_L2CS_NET;
+       if (model_name == "GZE_GENERAL_GAZE_NET")
+               return GazeTrackingTaskType::GZE_GENERAL_GAZE_NET;
+       if (model_name == "GZE_TINY_TRACKER")
+               return GazeTrackingTaskType::GZE_TINY_TRACKER;
        // TODO.
 
        throw InvalidParameter("Invalid gaze tracking model name.");
diff --git a/mv_machine_learning/gaze_tracking/src/GeneralGazeNet.cpp b/mv_machine_learning/gaze_tracking/src/GeneralGazeNet.cpp
new file mode 100644 (file)
index 0000000..426b1d7
--- /dev/null
@@ -0,0 +1,68 @@
+/**
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <cmath>
+#include <map>
+#include <string.h>
+
+#include "GeneralGazeNet.h"
+#include "MvMlException.h"
+#include "Postprocess.h"
+#include "mv_gaze_tracking_config.h"
+
+using namespace std;
+using namespace mediavision::inference;
+using namespace mediavision::machine_learning::exception;
+
+namespace mediavision
+{
+namespace machine_learning
+{
+template<typename T>
+GeneralGazeNet<T>::GeneralGazeNet(GazeTrackingTaskType task_type, std::shared_ptr<Config> config)
+               : GazeTracking<T>(task_type, config), _result()
+{}
+
+template<typename T> GeneralGazeNet<T>::~GeneralGazeNet()
+{}
+
+template<typename T> GazeTrackingResult &GeneralGazeNet<T>::result()
+{
+       // Clear _result object because result() function can be called every time user wants
+       // so make sure to clear existing result data before getting the data again.
+       _result = GazeTrackingResult();
+
+       vector<string> names;
+
+       GazeTracking<T>::getOutputNames(names);
+
+       vector<float> outputTensor;
+
+       GazeTracking<T>::getOutputTensor(names[0], outputTensor);
+
+       LOGD("GGazeEstimation::result() - outputTensor size: %zu", outputTensor.size());
+
+       _result.frame_number++;
+       return _result;
+}
+
+template class GeneralGazeNet<float>;
+template class GeneralGazeNet<unsigned char>;
+template class GeneralGazeNet<char>;
+
+}
+}