Add U2Net based face landmark detection
authorTae-Young Chung <ty83.chung@samsung.com>
Thu, 25 Apr 2024 05:13:52 +0000 (14:13 +0900)
committerKwanghoon Son <k.son@samsung.com>
Mon, 24 Jun 2024 03:53:24 +0000 (03:53 +0000)
This patch adds support for a U2Net-based face landmark detection model,
which provides 68 landmark points.
The model is provided at https://github.sec.samsung.net/tizen-vault/open_model_zoo/pull/9

Change-Id: I30cd9ef2d173c9cc05f43c579cc3d2f589bc120a
Signed-off-by: Tae-Young Chung <ty83.chung@samsung.com>
mv_machine_learning/landmark_detection/include/FldU2net.h [new file with mode: 0644]
mv_machine_learning/landmark_detection/include/landmark_detection_type.h
mv_machine_learning/landmark_detection/src/FacialLandmarkAdapter.cpp
mv_machine_learning/landmark_detection/src/FldU2net.cpp [new file with mode: 0644]

diff --git a/mv_machine_learning/landmark_detection/include/FldU2net.h b/mv_machine_learning/landmark_detection/include/FldU2net.h
new file mode 100644 (file)
index 0000000..5f274dc
--- /dev/null
@@ -0,0 +1,50 @@
+/**
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __FLD_U2NET_H__
+#define __FLD_U2NET_H__
+
+#include <memory>
+#include <mv_common.h>
+
+#include "LandmarkDetection.h"
+#include <mv_inference_type.h>
+
+namespace mediavision
+{
+namespace machine_learning
+{
+/**
+ * @brief U2Net based face landmark detector.
+ *
+ * Implements the LandmarkDetection interface for a U2Net face landmark
+ * model; the number of decoded points is set in the constructor (68).
+ *
+ * @tparam T output tensor element type (instantiated for unsigned char
+ *           and float in FldU2net.cpp).
+ */
+template<typename T> class FldU2net : public LandmarkDetection<T>
+{
+	using LandmarkDetection<T>::_config;
+	using LandmarkDetection<T>::_preprocess;
+	using LandmarkDetection<T>::_inference;
+
+private:
+	unsigned int _numberOfLandmarks; // number of landmark points the model outputs (set to 68 in the ctor)
+	LandmarkDetectionResult _result; // decoded result cache, rebuilt on every result() call
+
+public:
+	FldU2net(LandmarkDetectionTaskType task_type, std::shared_ptr<Config> config);
+	~FldU2net();
+
+	/// Decodes the latest output tensor into 2D landmark coordinates.
+	LandmarkDetectionResult &result() override;
+};
+
+} // namespace machine_learning
+} // namespace mediavision
+
+#endif
\ No newline at end of file
index fe20397a63e462e1433172b824cfc8b4c37ba644..23d3bf38fd2c6bee5b92c95986a89b818ba0f3e1 100644 (file)
@@ -40,7 +40,7 @@ struct LandmarkDetectionResult : public OutputBaseType {
        std::vector<std::string> labels;
 };
 
-enum class LandmarkDetectionTaskType { LANDMARK_DETECTION_TASK_NONE = 0, FLD_TWEAK_CNN, PLD_CPM };
+enum class LandmarkDetectionTaskType { LANDMARK_DETECTION_TASK_NONE = 0, FLD_TWEAK_CNN, FLD_U2NET, PLD_CPM };
 
 }
 }
index a4ba0adf9ef7ce2b64f2ecb906ba89e8335c8b4a..256dc2bf011e7564bc6b55a101dcb71b9071fd78 100644 (file)
@@ -15,6 +15,7 @@
  */
 
 #include "FacialLandmarkAdapter.h"
+#include "FldU2net.h"
 #include "MvMlException.h"
 #include "mv_landmark_detection_config.h"
 
@@ -46,6 +47,9 @@ template<typename U> void FacialLandmarkAdapter::create(LandmarkDetectionTaskTyp
        case LandmarkDetectionTaskType::FLD_TWEAK_CNN:
                _landmark_detection = make_unique<FldTweakCnn<U> >(task_type, _config);
                break;
+       case LandmarkDetectionTaskType::FLD_U2NET:
+               _landmark_detection = make_unique<FldU2net<U> >(task_type, _config);
+               break;
        default:
                throw InvalidOperation("Invalid landmark detection task type.");
        }
@@ -79,6 +83,8 @@ LandmarkDetectionTaskType FacialLandmarkAdapter::convertToTaskType(string model_
 
        if (model_name == "FLD_TWEAK_CNN")
                return LandmarkDetectionTaskType::FLD_TWEAK_CNN;
+       else if (model_name == "FLD_U2NET")
+               return LandmarkDetectionTaskType::FLD_U2NET;
        // TODO.
 
        throw InvalidParameter("Invalid facial detection model name.");
diff --git a/mv_machine_learning/landmark_detection/src/FldU2net.cpp b/mv_machine_learning/landmark_detection/src/FldU2net.cpp
new file mode 100644 (file)
index 0000000..5ccf95e
--- /dev/null
@@ -0,0 +1,100 @@
+/**
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <map>
+#include <string.h>
+
+#include "FldU2net.h"
+#include "MvMlException.h"
+#include "Postprocess.h"
+
+using namespace std;
+using namespace mediavision::inference;
+using namespace mediavision::machine_learning::exception;
+
+namespace mediavision
+{
+namespace machine_learning
+{
+/**
+ * @brief Constructs the U2Net face landmark detector.
+ *
+ * @param task_type task type identifier forwarded to the base class.
+ * @param config    model/inference configuration forwarded to the base class.
+ */
+template<typename T>
+FldU2net<T>::FldU2net(LandmarkDetectionTaskType task_type, std::shared_ptr<Config> config)
+		// The U2Net face landmark model provides 68 points; initialize the
+		// member in the init list instead of assigning in the body.
+		: LandmarkDetection<T>(task_type, std::move(config)), _numberOfLandmarks(68), _result()
+{}
+
+// Nothing to release beyond base-class members; default the destructor.
+template<typename T> FldU2net<T>::~FldU2net() = default;
+
+template<typename T> LandmarkDetectionResult &FldU2net<T>::result()
+{
+       // Clear _result object because result() function can be called every time user wants
+       // so make sure to clear existing result data before getting the data again.
+       _result = LandmarkDetectionResult();
+
+       vector<string> names;
+
+       LandmarkDetection<T>::getOutputNames(names);
+
+       auto scoreMetaInfo = _config->getOutputMetaMap().at(names[0]);
+       auto decodingLandmark =
+                       static_pointer_cast<DecodingLandmark>(scoreMetaInfo->decodingTypeMap[DecodingType::LANDMARK]);
+
+       if (decodingLandmark->decoding_type != LandmarkDecodingType::BYPASS)
+               throw InvalidOperation("decoding type not support.");
+
+       if (decodingLandmark->coordinate_type != LandmarkCoordinateType::RATIO)
+               throw InvalidOperation("coordinate type not support.");
+
+       if (decodingLandmark->landmark_type != LandmarkType::SINGLE_2D)
+               throw InvalidOperation("landmark type not support.");
+
+       auto ori_src_width = static_cast<double>(_preprocess.getImageWidth()[0]);
+       auto ori_src_height = static_cast<double>(_preprocess.getImageHeight()[0]);
+       auto input_tensor_width = static_cast<double>(_inference->getInputWidth());
+       auto input_tensor_height = static_cast<double>(_inference->getInputHeight());
+
+       _result.number_of_landmarks = _numberOfLandmarks;
+
+       vector<float> score_tensor;
+
+       LandmarkDetection<T>::getOutputTensor(names[0], score_tensor);
+
+       // Calculate the ratio[A] between the original image size and the input tensor size.
+       auto width_ratio = ori_src_width / input_tensor_width;
+       auto height_ratio = ori_src_height / input_tensor_height;
+
+       // In case that landmark coordinate type is RATIO, output tensor buffer contains ratio values indicating
+       // the position of each landmark for the input tensor.
+       // Therefore, each landmark position for original image is as following,
+       //    x = [width A] * width of input tensor * width ratio value of output tensor.
+       //    y = [height A] * height of input tensor * height ratio value of output tensor.
+       for (unsigned int idx = 0; idx < _numberOfLandmarks; ++idx) {
+               _result.x_pos.push_back(
+                               static_cast<unsigned int>(width_ratio * input_tensor_width * score_tensor[idx + idx * 1]));
+               _result.y_pos.push_back(
+                               static_cast<unsigned int>(height_ratio * input_tensor_height * score_tensor[idx + idx * 1 + 1]));
+       }
+
+       return _result;
+}
+
+template class FldU2net<unsigned char>;
+template class FldU2net<float>;
+
+} // namespace machine_learning
+} // namespace mediavision
\ No newline at end of file