};
// Output of the image classification task. The three vectors are filled together by the
// post-processing code, so entry i of each one describes the same class.
struct ImageClassificationResult : public OutputBaseType {
- std::string label;
+ unsigned int number_of_classes {}; // number of reported classes; post-processing sets it to indices.size()
+ std::vector<std::string> labels; // label string looked up for each entry of indices, in the same order
+ std::vector<unsigned int> indices; // position of each reported class in the model output vector, highest score first
+ std::vector<float> confidences; // score of each reported class, in the same order as indices
};
enum class ImageClassificationTaskType {
#include <string.h>
#include <map>
+#include <set>
#include <algorithm>
#include "MvMlException.h"
output_vec[idx] = PostProcess::sigmoid(output_vec[idx]);
}
- _result.label = _labels[max_element(output_vec.begin(), output_vec.end()) - output_vec.begin()];
- LOGI("Label = %s", _result.label.c_str());
+ set<pair<float, int> > topScore;
+ for (unsigned int idx = 0; idx < output_vec.size(); ++idx) {
+ topScore.insert({ output_vec[idx], idx });
+
+ // Remove the smallest score if the set size exceeds topNumber
+ if (topScore.size() > decodingScore->topNumber)
+ topScore.erase(topScore.begin());
+ }
+
+ _result.confidences.clear();
+ _result.indices.clear();
+ for (auto it = topScore.rbegin(); it != topScore.rend(); it++) {
+ _result.confidences.push_back(it->first);
+ _result.indices.push_back(it->second);
+ }
+
+ _result.number_of_classes = _result.indices.size();
+
+ _result.labels.clear();
+ for (unsigned int idx : _result.indices)
+ _result.labels.push_back(_labels[idx]);
return _result;
}
*frame_number = result.frame_number;
// As of now, the result count of the image classification task is at most 1
// (0 when no label is available).
// TODO: consider multi-label image classification models later.
- *result_cnt = static_cast<unsigned int>(!result.label.empty());
+ *result_cnt = static_cast<unsigned int>(!result.labels.empty());
} catch (const BaseException &e) {
LOGE("%s", e.what());
return e.getError();
auto &result =
static_cast<ImageClassificationResult &>(machine_learning_native_get_result_cache(handle, TASK_NAME));
// As of now, the result count of the image classification task is at most 1 (0 when no label is available).
- unsigned int result_cnt = static_cast<unsigned int>(!result.label.empty());
+ unsigned int result_cnt = static_cast<unsigned int>(!result.labels.empty());
if (index >= result_cnt) {
LOGE("Invalid index(index = %u, result count = %u).", index, result_cnt);
return MEDIA_VISION_ERROR_INVALID_PARAMETER;
}
- *label = result.label.c_str();
+ *label = result.labels[0].c_str();
} catch (const BaseException &e) {
LOGE("%s", e.what());
return e.getError();