mv_machine_learning: code refactoring to GetClassificationResults function 18/264818/1
author    Inki Dae <inki.dae@samsung.com>
          Wed, 29 Sep 2021 09:37:43 +0000 (18:37 +0900)
committer Inki Dae <inki.dae@samsung.com>
          Thu, 30 Sep 2021 08:18:28 +0000 (17:18 +0900)
Refactored the GetClassificationResults member function of the Inference class
in preparation for the next round of refactoring.

What this patch does:
 - drop the code dependency on mPostProc from the Inference class by using
   a local PostProcess object instead.
 - drop code duplication when collecting image classification results by
   sharing the same result-filling code for both output paths.
 - minor code sliding.
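
For illustration, below is a minimal, self-contained sketch of the caller-side
pattern implied by the new signature: GetClassficationResults() now fills a
caller-owned result object passed by reference instead of writing through a
raw pointer. The types and values here are simplified stand-ins, not the real
media-vision definitions.

    #include <iostream>
    #include <string>
    #include <vector>

    struct ImageClassificationResults {
            int number_of_classes = 0;
            std::vector<int> indices;
            std::vector<float> confidences;
            std::vector<std::string> names;
    };

    // Simplified stand-in for Inference::GetClassficationResults(); it only
    // demonstrates the by-reference out-parameter, not the real scoring logic.
    int GetClassficationResults(ImageClassificationResults &results)
    {
            // The real function pops the top-N scores collected by a local
            // PostProcess object; here a single dummy entry is emitted instead.
            results.indices.push_back(0);
            results.confidences.push_back(0.9f);
            results.names.push_back("dummy_label");
            results.number_of_classes = 1;

            return 0; // MEDIA_VISION_ERROR_NONE in the real code
    }

    int main()
    {
            ImageClassificationResults results;

            // The caller passes the result object by reference; no pointer
            // indirection and no intermediate copy of a local struct is needed.
            if (GetClassficationResults(results) != 0)
                    return 1;

            std::cout << "number_of_classes: " << results.number_of_classes
                      << std::endl;

            return 0;
    }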

Change-Id: I721dc3081f730fcdb75dbe91ad743f906e076ea4
Signed-off-by: Inki Dae <inki.dae@samsung.com>
mv_machine_learning/mv_inference/inference/include/Inference.h
mv_machine_learning/mv_inference/inference/src/Inference.cpp
mv_machine_learning/mv_inference/inference/src/mv_inference_open.cpp

index 969b5a2..cad69e9 100644 (file)
@@ -293,7 +293,7 @@ namespace inference
                 * @since_tizen 5.5
                 * @return @c true on success, otherwise a negative error value
                 */
-               int GetClassficationResults(ImageClassificationResults *classificationResults);
+               int GetClassficationResults(ImageClassificationResults &classificationResults);
 
                /**
                 * @brief       Gets the ObjectDetectioResults
@@ -366,7 +366,6 @@ namespace inference
 
                Metadata mMetadata;
                PreProcess mPreProc;
-               PostProcess mPostProc;
 
        private:
                void CheckSupportedInferenceBackend();
index f5d5488..756f041 100755 (executable)
@@ -77,8 +77,7 @@ namespace inference
                        engine_config(),
                        mBackend(),
                        mMetadata(),
-                       mPreProc(),
-                       mPostProc()
+                       mPreProc()
        {
                LOGI("ENTER");
 
@@ -1094,20 +1093,24 @@ namespace inference
                return mSupportedInferenceBackend[backend];
        }
 
-       int Inference::GetClassficationResults(
-                       ImageClassificationResults *classificationResults)
+       int Inference::GetClassficationResults(ImageClassificationResults &results)
        {
                OutputMetadata& outputMeta = mMetadata.GetOutputMeta();
+               // Will contain top N results in ascending order.
+               std::vector<std::pair<float, int>> topScore;
+               auto threadHold = mConfig.mConfidenceThresHold;
+
+               results.number_of_classes = 0;
+
                if (outputMeta.IsParsed()) {
-                       std::vector<std::pair<float, int>> topScore;
-                       float value = 0.0f;
                        auto& info = outputMeta.GetScore();
-
                        std::vector<int> indexes = info.GetDimInfo().GetValidIndexAll();
+
                        if (indexes.size() != 1) {
                                LOGE("Invalid dim size. It should be 1");
                                return MEDIA_VISION_ERROR_INVALID_OPERATION;
                        }
+
                        int classes = mOutputLayerProperty.layers[info.GetName()].shape[indexes[0]];
 
                        if (!mOutputTensorBuffers.exist(info.GetName())) {
@@ -1115,8 +1118,14 @@ namespace inference
                                return MEDIA_VISION_ERROR_INVALID_OPERATION;
                        }
 
-                       mPostProc.ScoreClear(info.GetTopNumber());
+                       PostProcess postProc;
+
+                       postProc.ScoreClear(info.GetTopNumber());
+                       threadHold = info.GetThresHold();
+
                        for (int cId = 0; cId < classes; ++cId) {
+                               float value = 0.0f;
+
                                try {
                                        value = mOutputTensorBuffers.getValue<float>(info.GetName(), cId);
                                } catch (const std::exception& e) {
@@ -1129,35 +1138,18 @@ namespace inference
                                                                                        info.GetDeQuant()->GetScale(),
                                                                                        info.GetDeQuant()->GetZeroPoint());
                                }
-                               if (info.GetType() == INFERENCE_SCORE_TYPE_SIGMOID) {
+
+                               if (info.GetType() == INFERENCE_SCORE_TYPE_SIGMOID)
                                        value = PostProcess::sigmoid(value);
-                               }
 
-                               if (value < info.GetThresHold())
+                               if (value < threadHold)
                                        continue;
 
                                LOGI("id[%d]: %.3f", cId, value);
-                               mPostProc.ScorePush(value, cId);
-                       }
-                       mPostProc.ScorePop(topScore);
-
-                       ImageClassificationResults results;
-                       results.number_of_classes = 0;
-                       for (auto& value : topScore) {
-                               LOGI("score: %.3f, threshold: %.3f", value.first, info.GetThresHold());
-                               LOGI("idx:%d", value.second);
-                               LOGI("classProb: %.3f", value.first);
-
-                               results.indices.push_back(value.second);
-                               results.confidences.push_back(value.first);
-                               results.names.push_back(mUserListName[value.second]);
-                               results.number_of_classes++;
+                               postProc.ScorePush(value, cId);
                        }
 
-                       *classificationResults = results;
-                       LOGE("Inference: GetClassificationResults: %d\n",
-                               results.number_of_classes);
-
+                       postProc.ScorePop(topScore);
                } else {
                        tensor_t outputData;
 
@@ -1168,56 +1160,52 @@ namespace inference
                                return ret;
                        }
 
-                       // Will contain top N results in ascending order.
-                       std::vector<std::pair<float, int> > top_results;
+                       auto classes = outputData.dimInfo[0][1];
+                       auto *prediction = reinterpret_cast<float *>(outputData.data[0]);
+
+                       LOGI("class count: %d", classes);
+
                        std::priority_queue<std::pair<float, int>,
-                                                               std::vector<std::pair<float, int> >,
-                                                               std::greater<std::pair<float, int> > >
+                                                               std::vector<std::pair<float, int>>,
+                                                               std::greater<std::pair<float, int>>>
                                        top_result_pq;
-                       float value = 0.0f;
 
-                       int count = outputData.dimInfo[0][1];
-                       LOGI("count: %d", count);
-                       float *prediction = reinterpret_cast<float *>(outputData.data[0]);
-                       for (int i = 0; i < count; ++i) {
-                               value = prediction[i];
+                       for (int cId = 0; cId < classes; ++cId) {
+                               auto value = prediction[cId];
+
+                               if (value < threadHold)
+                                       continue;
 
                                // Only add it if it beats the threshold and has a chance at being in
                                // the top N.
-                               top_result_pq.push(std::pair<float, int>(value, i));
+                               top_result_pq.push(std::pair<float, int>(value, cId));
 
                                // If at capacity, kick the smallest value out.
-                               if (top_result_pq.size() > static_cast<size_t>(mConfig.mMaxOutputNumbers)) {
+                               if (top_result_pq.size() > static_cast<size_t>(mConfig.mMaxOutputNumbers))
                                        top_result_pq.pop();
-                               }
                        }
 
                        // Copy to output vector and reverse into descending order.
                        while (!top_result_pq.empty()) {
-                               top_results.push_back(top_result_pq.top());
+                               topScore.push_back(top_result_pq.top());
                                top_result_pq.pop();
                        }
-                       std::reverse(top_results.begin(), top_results.end());
-
-                       ImageClassificationResults results;
-                       results.number_of_classes = 0;
-                       for (auto& result : top_results) {
-                               if (result.first < mThreshold)
-                                       continue;
 
-                               LOGI("class Idx: %d, Prob: %.4f", result.second, result.first);
+                       std::reverse(topScore.begin(), topScore.end());
+               }
 
-                               results.indices.push_back(result.second);
-                               results.confidences.push_back(result.first);
-                               results.names.push_back(mUserListName[result.second]);
-                               results.number_of_classes++;
-                       }
+               for (auto& score : topScore) {
+                       LOGI("score: %.3f, threshold: %.3f", score.first, threadHold);
+                       LOGI("idx:%d", score.second);
+                       LOGI("classProb: %.3f", score.first);
 
-                       *classificationResults = results;
-                       LOGE("Inference: GetClassificationResults: %d\n",
-                               results.number_of_classes);
+                       results.indices.push_back(score.second);
+                       results.confidences.push_back(score.first);
+                       results.names.push_back(mUserListName[score.second]);
+                       results.number_of_classes++;
                }
 
+               LOGE("Inference: GetClassificationResults: %d\n", results.number_of_classes);
                return MEDIA_VISION_ERROR_NONE;
        }
 
index ed09bb0..41c62df 100644 (file)
@@ -638,33 +638,30 @@ int mv_inference_image_classify_open(
 
        ImageClassificationResults classificationResults;
 
-       ret = pInfer->GetClassficationResults(&classificationResults);
+       ret = pInfer->GetClassficationResults(classificationResults);
        if (ret != MEDIA_VISION_ERROR_NONE) {
                LOGE("Fail to get inference results");
                return ret;
        }
 
        int numberOfOutputs = classificationResults.number_of_classes;
-
        static const int START_CLASS_NUMBER = 10;
        static std::vector<const char *> names(START_CLASS_NUMBER);
 
        if (numberOfOutputs > START_CLASS_NUMBER)
                names.resize(numberOfOutputs);
 
-       LOGE("mv_inference_open: number_of_classes: %d\n",
-                classificationResults.number_of_classes);
+       LOGE("mv_inference_open: number_of_classes: %d\n", numberOfOutputs);
 
-       for (int n = 0; n < numberOfOutputs; ++n) {
-               LOGE("names: %s", classificationResults.names[n].c_str());
-               names[n] = classificationResults.names[n].c_str();
+       for (int output_index = 0; output_index < numberOfOutputs; ++output_index) {
+               LOGE("names: %s", classificationResults.names[output_index].c_str());
+               names[output_index] = classificationResults.names[output_index].c_str();
        }
 
-       int *indices = classificationResults.indices.data();
-       float *confidences = classificationResults.confidences.data();
+       auto *indices = classificationResults.indices.data();
+       auto *confidences = classificationResults.confidences.data();
 
-       classified_cb(source, numberOfOutputs, indices, names.data(), confidences,
-                                 user_data);
+       classified_cb(source, numberOfOutputs, indices, names.data(), confidences, user_data);
 
        return ret;
 }