mv_machine_learning: add data augmentation feature
authorInki Dae <inki.dae@samsung.com>
Tue, 29 Mar 2022 03:51:08 +0000 (12:51 +0900)
committerHyunsoo Park <hance.park@samsung.com>
Tue, 26 Apr 2022 03:59:19 +0000 (12:59 +0900)
[Version] 0.18.0-0
[Issue type] new feature

Added data augmentation feature which generates several input images
preprocessed in various ways to extend a given input data to mutiple
ones. With this feature, the accuracy is a little improved.
So this patch also corrects a wrong expected answer of test case.

As of now, this feature supports two preprocessing classes - horizontal
flip and rotation, and it uses default and flip classes in default.

Change-Id: I0e18e761c020ffaa8e4cca660f631413c3d5c69c
Signed-off-by: Inki Dae <inki.dae@samsung.com>
13 files changed:
mv_machine_learning/face_recognition/include/face_recognition.h
mv_machine_learning/face_recognition/src/face_recognition.cpp
mv_machine_learning/face_recognition/src/mv_face_recognition_open.cpp
mv_machine_learning/training/include/data_augment.h [new file with mode: 0644]
mv_machine_learning/training/include/data_augment_default.h [new file with mode: 0644]
mv_machine_learning/training/include/data_augment_flip.h [new file with mode: 0644]
mv_machine_learning/training/include/data_augment_rotate.h [new file with mode: 0644]
mv_machine_learning/training/src/data_augment.cpp [new file with mode: 0644]
mv_machine_learning/training/src/data_augment_default.cpp [new file with mode: 0644]
mv_machine_learning/training/src/data_augment_flip.cpp [new file with mode: 0644]
mv_machine_learning/training/src/data_augment_rotate.cpp [new file with mode: 0644]
packaging/capi-media-vision.spec
test/testsuites/machine_learning/face_recognition/test_face_recognition.cpp

index ec7fb2c426fab02de8248153ca985674489a6575..69524785957c7e5c7c347912fe2fbb3b0f0d3807 100644 (file)
@@ -24,6 +24,9 @@
 #include "label_manager.h"
 #include "face_net_info.h"
 #include "simple_shot.h"
+#include "data_augment_default.h"
+#include "data_augment_flip.h"
+#include "data_augment_rotate.h"
 
 typedef struct {
        std::string backbone_backend_name;
@@ -39,6 +42,7 @@ typedef struct {
 class FaceRecognition {
 private:
        FaceRecognitionConfig _config;
+       std::vector<std::shared_ptr<DataAugment>> _data_augments;
 
        void CheckFeatureVectorFile(std::unique_ptr<FeatureVectorManager>& old_fvm, std::unique_ptr<FeatureVectorManager>& new_fvm);
        std::unique_ptr<DataSetManager> CreateDSM(const training_engine_backend_type_e backend_type);
@@ -71,6 +75,7 @@ public:
        std::unique_ptr<InferenceEngineHelper>& GetInternal() { return _internal; }
        std::unique_ptr<InferenceEngineHelper>& GetBackbone() { return _backbone; }
        std::vector<model_layer_info>& GetBackboneInputLayerInfo() { return _face_net_info->GetInputLayerInfo(); }
+       std::vector<std::shared_ptr<DataAugment>>& GetDataAugment() { return _data_augments; }
 };
 
 #endif
\ No newline at end of file
index d7b461fab9da514edd484e5697e4335724a530d7..c412e2b658bb35d217d9ceb5598253443a5880d2 100644 (file)
@@ -40,7 +40,9 @@ using namespace Mediavision::MachineLearning::Exception;
 FaceRecognition::FaceRecognition() :
                _initialized(true), _prepared(false), _internal(), _backbone(), _face_net_info(), _training_model(), _label_manager()
 {
-
+       _data_augments.push_back(std::make_shared<DataAugmentDefault>());
+       _data_augments.push_back(std::make_shared<DataAugmentFlip>());
+       /* Add other data argument classes. */
 }
 
 FaceRecognition::~FaceRecognition()
index 222fa5fbcf3e0f79949ebcbc4d5c6ec819d7d7fd..537e8e9de6359bdc122c94e5959417054afcdfb9 100644 (file)
 
 #include <algorithm>
 #include <dlog.h>
+#include <memory>
 
 #include "face_recognition.h"
 #include "feature_vector_manager.h"
 #include "backbone_model_info.h"
 #include "mv_face_recognition_open.h"
+#include "machine_learning_exception.h"
 
 using namespace std;
+using namespace Mediavision::MachineLearning::Exception;
 
 int mv_face_recognition_create_open(mv_face_recognition_h *handle)
 {
@@ -142,12 +145,6 @@ int mv_face_recognition_prepare_open(mv_face_recognition_h handle)
                return ret;
        }
 
-       FaceRecognition *pFace = static_cast<FaceRecognition *>(handle);
-
-       ret = pFace->Prepare();
-       if (ret != MEDIA_VISION_ERROR_NONE)
-               LOGE("Fail to prepare face recognition.");
-
        LOGD("LEAVE");
 
        return ret;
@@ -162,7 +159,6 @@ int mv_face_recognition_register_open(mv_face_recognition_h handle, mv_source_h
                return MEDIA_VISION_ERROR_INVALID_PARAMETER;
        }
 
-       FaceRecognition *pFace = static_cast<FaceRecognition *>(handle);
        mv_colorspace_e colorspace = MEDIA_VISION_COLORSPACE_INVALID;
        unsigned int width = 0, height = 0, bufferSize = 0;
        unsigned char *buffer = NULL;
@@ -180,19 +176,40 @@ int mv_face_recognition_register_open(mv_face_recognition_h handle, mv_source_h
                return MEDIA_VISION_ERROR_NOT_SUPPORTED_FORMAT;
        }
 
-       vector<float> src_vec;
-       vector<model_layer_info>& input_layer_info = pFace->GetBackboneInputLayerInfo();
-       // TODO. consider mutiple tensor info.
-       unsigned int re_width = input_layer_info[0].tensor_info.shape[0];
-       unsigned int re_height = input_layer_info[0].tensor_info.shape[1];
-
-       LOGD("Convert mv source(WxH) : %d x %d => %d x %d", width, height, re_width, re_height);
-
-       FeatureVectorManager::GetVecFromRGB(buffer, src_vec, width, height, re_width, re_height);
+       FaceRecognition *pFace = static_cast<FaceRecognition *>(handle);
 
-       int ret = pFace->RegisterNewFace(src_vec, string(label));
-       if (ret != MEDIA_VISION_ERROR_NONE)
-               LOGE("Fail to register new face.");
+       auto data_augments = pFace->GetDataAugment();
+       int ret = MEDIA_VISION_ERROR_NONE;
+
+       for (auto& data_augment : data_augments) {
+               ret = pFace->Prepare();
+               if (ret != MEDIA_VISION_ERROR_NONE) {
+                       LOGE("Fail to prepare face recognition.");
+                       break;
+               }
+
+               vector<model_layer_info>& input_layer_info = pFace->GetBackboneInputLayerInfo();
+               // TODO. consider mutiple tensor info.
+               unsigned int re_width = input_layer_info[0].tensor_info.shape[0];
+               unsigned int re_height = input_layer_info[0].tensor_info.shape[1];
+               vector<float> src_vec;
+
+               try {
+                       data_augment->Preprocess(buffer, src_vec, width, height, re_width, re_height);
+               } catch (InvalidParameter& e) {
+                       LOGE("%s", e.what());
+                       ret = e.getError();
+                       break;
+               }
+
+               ret = pFace->RegisterNewFace(src_vec, string(label));
+               if (ret != MEDIA_VISION_ERROR_NONE) {
+                       LOGE("Fail to register new face.");
+                       break;
+               }
+
+               src_vec.clear();
+       }
 
        LOGD("LEAVE");
 
diff --git a/mv_machine_learning/training/include/data_augment.h b/mv_machine_learning/training/include/data_augment.h
new file mode 100644 (file)
index 0000000..b941f5e
--- /dev/null
@@ -0,0 +1,46 @@
+/**
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __DATA_AUGMENT_H__
+#define __DATA_AUGMENT_H__
+
+#include <iostream>
+#include <vector>
+
+#include <opencv2/opencv.hpp>
+#include <opencv2/imgproc/imgproc.hpp>
+
+enum {
+       DATA_AUGMENT_FLIP = 0,
+       DATA_AUGMENT_ROTATE,
+};
+
+class DataAugment {
+protected:
+       unsigned int _type;
+       unsigned int _degree;
+
+public:
+       DataAugment();
+       DataAugment(unsigned int degree);
+       virtual ~DataAugment();
+
+       void Resize(cv::Mat& src, std::vector<float>& vec, int width, int height);
+       virtual void Preprocess(unsigned char *in_data, std::vector<float>& out_vec,
+                                       int width, int height, int re_width, int re_height);
+};
+
+#endif
\ No newline at end of file
diff --git a/mv_machine_learning/training/include/data_augment_default.h b/mv_machine_learning/training/include/data_augment_default.h
new file mode 100644 (file)
index 0000000..04e3c17
--- /dev/null
@@ -0,0 +1,34 @@
+/**
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __DATA_AUGMENT_DEFAULT_H__
+#define __DATA_AUGMENT_DEFAULT_H__
+
+#include <iostream>
+#include <vector>
+
+#include "data_augment.h"
+
+class DataAugmentDefault : public DataAugment {
+public:
+       DataAugmentDefault();
+       ~DataAugmentDefault();
+
+       void Preprocess(unsigned char *in_data, std::vector<float>& out_vec,
+                                       int width, int height, int re_width, int re_height) final;
+};
+
+#endif
\ No newline at end of file
diff --git a/mv_machine_learning/training/include/data_augment_flip.h b/mv_machine_learning/training/include/data_augment_flip.h
new file mode 100644 (file)
index 0000000..69c115e
--- /dev/null
@@ -0,0 +1,34 @@
+/**
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __DATA_AUGMENT_FLIP_H__
+#define __DATA_AUGMENT_FLIP_H__
+
+#include <iostream>
+#include <vector>
+
+#include "data_augment.h"
+
+class DataAugmentFlip : public DataAugment {
+public:
+       DataAugmentFlip();
+       ~DataAugmentFlip();
+
+       void Preprocess(unsigned char *in_data, std::vector<float>& out_vec,
+                                       int width, int height, int re_width, int re_height) final;
+};
+
+#endif
\ No newline at end of file
diff --git a/mv_machine_learning/training/include/data_augment_rotate.h b/mv_machine_learning/training/include/data_augment_rotate.h
new file mode 100644 (file)
index 0000000..428bb45
--- /dev/null
@@ -0,0 +1,34 @@
+/**
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __DATA_AUGMENT_ROTATE_H__
+#define __DATA_AUGMENT_ROTATE_H__
+
+#include <iostream>
+#include <vector>
+
+#include "data_augment.h"
+
+class DataAugmentRotate : public DataAugment {
+public:
+       DataAugmentRotate(unsigned int degree);
+       ~DataAugmentRotate();
+
+       void Preprocess(unsigned char *in_data, std::vector<float>& out_vec,
+                                       int width, int height, int re_width, int re_height) final;
+};
+
+#endif
\ No newline at end of file
diff --git a/mv_machine_learning/training/src/data_augment.cpp b/mv_machine_learning/training/src/data_augment.cpp
new file mode 100644 (file)
index 0000000..09bc734
--- /dev/null
@@ -0,0 +1,59 @@
+/**
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "data_augment.h"
+
+using namespace std;
+
+DataAugment::DataAugment() : _type(), _degree()
+{
+
+}
+
+DataAugment::DataAugment(unsigned int degree) : _type(), _degree(degree)
+{
+
+}
+
+DataAugment::~DataAugment()
+{
+
+}
+
+void DataAugment::Resize(cv::Mat& src, vector<float>& out_vec, int width, int height)
+{
+       cv::Mat resized;
+
+       resize(src, resized, cv::Size(width, height), 0, 0, cv::INTER_CUBIC);
+
+       cv::Mat floatSrc;
+
+       resized.convertTo(floatSrc, CV_32FC3);
+
+       cv::Mat meaned = cv::Mat(floatSrc.size(), CV_32FC3, cv::Scalar(127.5f, 127.5f, 127.5f));
+       cv::Mat dst;
+
+       cv::subtract(floatSrc, meaned, dst);
+       dst /= 127.5f;
+
+       out_vec.assign((float *)dst.data, (float *)dst.data + dst.total() * dst.channels());
+}
+
+void DataAugment::Preprocess(unsigned char *in_data, std::vector<float>& out_vec,
+                                                         int width, int height, int re_width, int re_height)
+{
+       return Preprocess(in_data, out_vec, width, height, re_width, re_height);
+}
\ No newline at end of file
diff --git a/mv_machine_learning/training/src/data_augment_default.cpp b/mv_machine_learning/training/src/data_augment_default.cpp
new file mode 100644 (file)
index 0000000..dffacc5
--- /dev/null
@@ -0,0 +1,37 @@
+/**
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "data_augment_default.h"
+
+using namespace std;
+
+DataAugmentDefault::DataAugmentDefault() : DataAugment()
+{
+
+}
+
+DataAugmentDefault::~DataAugmentDefault()
+{
+
+}
+
+void DataAugmentDefault::Preprocess(unsigned char *in_data, vector<float>& out_vec,
+                                                                        int width, int height, int re_width, int re_height)
+{
+       cv::Mat cvSrc = cv::Mat(cv::Size(width, height), CV_MAKETYPE(CV_8U, 3), in_data).clone();
+
+       Resize(cvSrc, out_vec, re_width, re_height);
+}
\ No newline at end of file
diff --git a/mv_machine_learning/training/src/data_augment_flip.cpp b/mv_machine_learning/training/src/data_augment_flip.cpp
new file mode 100644 (file)
index 0000000..648363b
--- /dev/null
@@ -0,0 +1,41 @@
+/**
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "data_augment_flip.h"
+
+using namespace std;
+
+DataAugmentFlip::DataAugmentFlip() : DataAugment()
+{
+
+}
+
+DataAugmentFlip::~DataAugmentFlip()
+{
+
+}
+
+void DataAugmentFlip::Preprocess(unsigned char *in_data, vector<float>& out_vec,
+                                                                 int width, int height, int re_width, int re_height)
+{
+       cv::Mat cvSrc = cv::Mat(cv::Size(width, height), CV_MAKETYPE(CV_8U, 3), in_data).clone();
+
+       cv::Mat cvFlip;
+
+       cv::flip(cvSrc, cvFlip, 1);
+
+       Resize(cvFlip, out_vec, re_width, re_height);
+}
\ No newline at end of file
diff --git a/mv_machine_learning/training/src/data_augment_rotate.cpp b/mv_machine_learning/training/src/data_augment_rotate.cpp
new file mode 100644 (file)
index 0000000..9905b48
--- /dev/null
@@ -0,0 +1,59 @@
+/**
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "machine_learning_exception.h"
+#include "data_augment_rotate.h"
+
+using namespace std;
+using namespace Mediavision::MachineLearning::Exception;
+
+DataAugmentRotate::DataAugmentRotate(unsigned int degree) : DataAugment(degree)
+{
+
+}
+
+DataAugmentRotate::~DataAugmentRotate()
+{
+
+}
+
+void DataAugmentRotate::Preprocess(unsigned char *in_data, vector<float>& out_vec,
+                                                                 int width, int height, int re_width, int re_height)
+{
+       cv::Mat cvSrc = cv::Mat(cv::Size(width, height), CV_MAKETYPE(CV_8U, 3), in_data).clone();
+
+       cv::Mat cvRotate;
+       int rotate_code = 0;
+
+       switch (_degree) {
+       case 90:
+               rotate_code = cv::ROTATE_90_CLOCKWISE;
+               break;
+       case -90:
+       case 270:
+               rotate_code = cv::ROTATE_90_COUNTERCLOCKWISE;
+               break;
+       case 180:
+               rotate_code = cv::ROTATE_180;
+               break;
+       default:
+               throw InvalidParameter("Invalid degree value.");
+       }
+
+       cv::rotate(cvSrc, cvRotate, rotate_code);
+
+       Resize(cvRotate, out_vec, re_width, re_height);
+}
index f80da2c093aac6abe38cd133ab39b2058ca93324..9cc6c0d948f649236e30f2e2b74fd13c68311fb7 100644 (file)
@@ -1,6 +1,6 @@
 Name:        capi-media-vision
 Summary:     Media Vision library for Tizen Native API
-Version:     0.17.3
+Version:     0.18.0
 Release:     0
 Group:       Multimedia/Framework
 License:     Apache-2.0 and BSD-3-Clause
index c3c83922f3b30f732b88ff512f745521f5b4a30b..5e4654480b72833e8787162efa62fae555397998 100644 (file)
@@ -133,7 +133,7 @@ TEST(FaceRecognitionTest, FaceRecognitionClassWithEachLabelRemovalShouldBeOk)
                        { "7779", "7779", "2929", "2929", "7779",
                        "2929", "7779", "2929", "2929", "7779",
                        "2929", "7779", "7779", "7779", "7779" },
-                       { "7779", "3448", "none", "none", "3448",
+                       { "3448", "3448", "none", "none", "3448",
                        "3448", "7779", "none", "none", "3448",
                        "none", "7779", "7779", "7779", "7779" },
                        { "3448", "3448", "2929", "2929", "3448",