mv_machine_learning: modify ImageClassificationResult 96/305296/2
author: Vibhav Aggarwal <v.aggarwal@samsung.com>
Wed, 31 Jan 2024 04:52:01 +0000 (13:52 +0900)
committer: Vibhav Aggarwal <v.aggarwal@samsung.com>
Thu, 1 Feb 2024 04:16:20 +0000 (13:16 +0900)
[Issue type] code improvement

This patch modifies the ImageClassificationResult struct
to hold multiple labels, along with the index and
confidence of each label. This will be useful for
reimplementing the legacy APIs using ITask.

Change-Id: Ic57fd77f10f762997b26553843bb43acefb2f114
Signed-off-by: Vibhav Aggarwal <v.aggarwal@samsung.com>
mv_machine_learning/image_classification/include/image_classification_type.h
mv_machine_learning/image_classification/src/ImageClassificationDefault.cpp
mv_machine_learning/image_classification/src/mv_image_classification.cpp

index 92277d4ab73a1c18eb454b51ff6d3b21852b83b3..db663b518b0a0a05ab4841448ac4a9a56bfdf19a 100644 (file)
@@ -33,7 +33,10 @@ struct ImageClassificationInput : public InputBaseType {
 };
 
 struct ImageClassificationResult : public OutputBaseType {
-       std::string label;
+       unsigned int number_of_classes {};
+       std::vector<std::string> labels;
+       std::vector<unsigned int> indices;
+       std::vector<float> confidences;
 };
 
 enum class ImageClassificationTaskType {
index 01f8a34af641a8e5c58996fd0c2eafaa3ac722af..36e48ec26af50e12dedaa2db40cc19dae0122e70 100644 (file)
@@ -16,6 +16,7 @@
 
 #include <string.h>
 #include <map>
+#include <set>
 #include <algorithm>
 
 #include "MvMlException.h"
@@ -57,8 +58,27 @@ template<typename T> ImageClassificationResult &ImageClassificationDefault<T>::r
                        output_vec[idx] = PostProcess::sigmoid(output_vec[idx]);
        }
 
-       _result.label = _labels[max_element(output_vec.begin(), output_vec.end()) - output_vec.begin()];
-       LOGI("Label = %s", _result.label.c_str());
+       set<pair<float, int> > topScore;
+       for (unsigned int idx = 0; idx < output_vec.size(); ++idx) {
+               topScore.insert({ output_vec[idx], idx });
+
+               // Remove the smallest score if the set size exceeds topNumber
+               if (topScore.size() > decodingScore->topNumber)
+                       topScore.erase(topScore.begin());
+       }
+
+       _result.confidences.clear();
+       _result.indices.clear();
+       for (auto it = topScore.rbegin(); it != topScore.rend(); it++) {
+               _result.confidences.push_back(it->first);
+               _result.indices.push_back(it->second);
+       }
+
+       _result.number_of_classes = _result.indices.size();
+
+       _result.labels.clear();
+       for (unsigned int idx : _result.indices)
+               _result.labels.push_back(_labels[idx]);
 
        return _result;
 }
index df79d2429f5d60f5ea3a9052e4ce35a5696427d7..a5237d688fbc74e8c175520d3758db1e3fa561f0 100644 (file)
@@ -178,7 +178,7 @@ int mv_image_classification_get_result_count(mv_image_classification_h handle, u
                *frame_number = result.frame_number;
                // As of now, the result count of image classification task is always 1.
                // TODO. consider for Multilabel image classification model later.
-               *result_cnt = static_cast<unsigned int>(!result.label.empty());
+               *result_cnt = static_cast<unsigned int>(!result.labels.empty());
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
@@ -201,13 +201,13 @@ int mv_image_classification_get_result(mv_image_classification_h handle, unsigne
                auto &result =
                                static_cast<ImageClassificationResult &>(machine_learning_native_get_result_cache(handle, TASK_NAME));
                // As of now, the result count of image classification task is always 1.
-               unsigned int result_cnt = static_cast<unsigned int>(!result.label.empty());
+               unsigned int result_cnt = static_cast<unsigned int>(!result.labels.empty());
                if (index >= result_cnt) {
                        LOGE("Invalid index(index = %u, result count = %u).", index, result_cnt);
                        return MEDIA_VISION_ERROR_INVALID_PARAMETER;
                }
 
-               *label = result.label.c_str();
+               *label = result.labels[0].c_str();
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();