mv_machine_learning: fix a bug after unregistering
authorInki Dae <inki.dae@samsung.com>
Mon, 10 Jul 2023 06:21:55 +0000 (15:21 +0900)
committerKwanghoon Son <k.son@samsung.com>
Wed, 12 Jul 2023 09:45:35 +0000 (18:45 +0900)
[Issue type] : bug fix

Fix a bug that mv_face_recognition_get_confidence() call fails after
unregistering a given label from the label file.

This issue was that internal result structure can be updated only if
one of tensor value must meet the decision threshold. Therefore,
this patch makes the internal result structure to updated regardless
of the decision threshold, updates the test case, and cleans up code.

Change-Id: I0aa9d70e0b89616ee8390a1ef37986f98f6c0020
Signed-off-by: Inki Dae <inki.dae@samsung.com>
mv_machine_learning/face_recognition/include/face_recognition.h
mv_machine_learning/face_recognition/src/face_recognition.cpp
test/testsuites/machine_learning/face_recognition/test_face_recognition.cpp

index 7f2e6fd8f5bdfe1eed5a891315ac3d3a94e004e3..b09e6778992c958c74914c17a33a2036603b099d 100644 (file)
@@ -93,7 +93,7 @@ private:
        void importLabel();
        void checkFeatureVectorFile(std::string fv_file_name, std::string new_fv_file_name);
        void storeDataSet(std::unique_ptr<DataSetManager> &data_set, unsigned int label_cnt);
-       int getAnswer();
+       void checkResult();
 
 public:
        FaceRecognition() = default;
index 946748f5b4082ad4e9c3da09f7b36f9d5619458c..999dab8513698dde72c9c217a0befbc34b904cf3 100644 (file)
@@ -195,44 +195,22 @@ int FaceRecognition::registerNewFace(std::vector<float> &input_vec, string label
        return MEDIA_VISION_ERROR_NONE;
 }
 
-int FaceRecognition::getAnswer()
+void FaceRecognition::checkResult()
 {
-       int answer_idx;
+       // Check decision threshold.
+       if (_result.raw_data[_result.label_idx] < _label_manager->getDecisionThreshold())
+               throw NoData("Not meet decision threshold.");
 
-       string result_str;
+       float weighted = _result.raw_data[_result.label_idx] * _label_manager->getDecisionWeight();
 
-       try {
-               for (auto &r : _result.raw_data)
-                       result_str += to_string(r) + " ";
-
-               LOGD("raw data = %s", result_str.c_str());
-
-               answer_idx = max_element(_result.raw_data.begin(), _result.raw_data.end()) - _result.raw_data.begin();
-
-               // Check decision threshold.
-               if (_result.raw_data[answer_idx] < _label_manager->getDecisionThreshold()) {
-                       throw NoData("Not meet decision threshold.");
-               }
+       // Check decision weight threshold.
+       for (const auto &r : _result.raw_data) {
+               if (_result.raw_data[_result.label_idx] == r)
+                       continue;
 
-               float weighted = _result.raw_data[answer_idx] * _label_manager->getDecisionWeight();
-
-               // Check decision weight threshold.
-               for (auto &r : _result.raw_data) {
-                       if (_result.raw_data[answer_idx] == r)
-                               continue;
-
-                       if (weighted < r)
-                               throw NoData("Not meet decision weight threshold");
-               }
-
-               _result.label_idx = answer_idx;
-               _result.is_valid = true;
-       } catch (const BaseException &e) {
-               LOGE("%s", e.what());
-               return e.getError();
+               if (weighted < r)
+                       throw NoData("Not meet decision weight threshold");
        }
-
-       return MEDIA_VISION_ERROR_NONE;
 }
 
 int FaceRecognition::recognizeFace(std::vector<float> &input_vec)
@@ -302,20 +280,37 @@ int FaceRecognition::recognizeFace(std::vector<float> &input_vec)
 
                auto raw_buffer = static_cast<float *>(internal_output_buffer->buffer);
 
+               // Update the result because user can want to get the result regardless of the decision threshold
+               // to check how many people are registered to the label file.
+               //
+               // TODO. as for this, we may need to introduce a new API to provide the number of people later.
                _result.raw_data.clear();
                copy(raw_buffer, raw_buffer + internal_output_buffer->size / sizeof(float), back_inserter(_result.raw_data));
 
+               string result_str;
+
+               for (const auto &r : _result.raw_data)
+                       result_str += to_string(r) + " ";
+
+               LOGD("raw data = %s", result_str.c_str());
+
                _result.labels.clear();
                _result.labels = _label_manager->getLabels();
+
+               unsigned int answer_index =
+                               max_element(_result.raw_data.begin(), _result.raw_data.end()) - _result.raw_data.begin();
+               _result.label_idx = answer_index;
+               _result.is_valid = true;
+
                _status = WorkingStatus::INFERENCED;
 
-               return getAnswer();
+               checkResult();
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
        }
 
-       return MEDIA_VISION_ERROR_INVALID_OPERATION;
+       return MEDIA_VISION_ERROR_NONE;
 }
 
 int FaceRecognition::deleteLabel(string label_name)
index f49462d554bae8b4157b5408dbcb02b6958a8ac6..d4102c2549effd991518b42c1fc87dfa84df4a73 100644 (file)
@@ -297,6 +297,39 @@ TEST(FaceRecognitionTest, LabelUpdateAfterInferenceShouldBeOk)
                ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
        }
 
+       // Remove "2929" label from the label file.
+       ret = mv_face_recognition_unregister(handle, "2929");
+       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+       ret = mv_face_recognition_get_confidence(handle, &confidences, &num_of_confidences);
+       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+       // num_of_confidences must be 3 yet because of no inference request.
+       ASSERT_EQ(num_of_confidences, 3);
+
+       const string image_path = string(TRAINING_IMAGE_PATH) + image_names[0];
+       mv_source_h mv_source = NULL;
+
+       ret = mv_create_source(&mv_source);
+       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+       ret = ImageHelper::loadImageToSource(image_path.c_str(), mv_source);
+       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+       ret = mv_face_recognition_inference(handle, mv_source);
+       if (ret != MEDIA_VISION_ERROR_NO_DATA) {
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+       }
+
+       ret = mv_destroy_source(mv_source);
+       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+       ret = mv_face_recognition_get_confidence(handle, &confidences, &num_of_confidences);
+       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+       // num_of_confidence must be 2 now.
+       ASSERT_EQ(num_of_confidences, 2);
+
        ret = mv_face_recognition_destroy(handle);
        ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);